res.setStatus(HttpServletResponse.SC_NOT_FOUND)
- UIUtils.basicSparkPage(msg, "Not Found").foreach { n =>
+ UIUtils.basicSparkPage(req, msg, "Not Found").foreach { n =>
res.getWriter().write(n.toString)
}
return
@@ -124,7 +124,7 @@ class HistoryServer(
attachHandler(ApiRootResource.getServletHandler(this))
- attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static"))
+ addStaticHandler(SparkUI.STATIC_RESOURCE_DIR)
val contextHandler = new ServletContextHandler
contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX)
@@ -150,14 +150,17 @@ class HistoryServer(
ui: SparkUI,
completed: Boolean) {
assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs")
- ui.getHandlers.foreach(attachHandler)
- addFilters(ui.getHandlers, conf)
+ handlers.synchronized {
+ ui.getHandlers.foreach(attachHandler)
+ }
}
/** Detach a reconstructed UI from this server. Only valid after bind(). */
override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = {
assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs")
- ui.getHandlers.foreach(detachHandler)
+ handlers.synchronized {
+ ui.getHandlers.foreach(detachHandler)
+ }
provider.onUIDetached(appId, attemptId, ui)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala
index efdbf672bb52f..25ba9edb9e014 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala
@@ -49,4 +49,19 @@ private[spark] object config {
.intConf
.createWithDefault(18080)
+ val FAST_IN_PROGRESS_PARSING =
+ ConfigBuilder("spark.history.fs.inProgressOptimization.enabled")
+ .doc("Enable optimized handling of in-progress logs. This option may leave finished " +
+ "applications that fail to rename their event logs listed as in-progress.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val END_EVENT_REPARSE_CHUNK_SIZE =
+ ConfigBuilder("spark.history.fs.endEventReparseChunkSize")
+ .doc("How many bytes to parse at the end of log files looking for the end event. " +
+ "This is used to speed up generation of application listings by skipping unnecessary " +
+ "parts of event log files. It can be disabled by setting this config to 0.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefaultString("1m")
+
}
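The two entries above follow the usual ConfigBuilder pattern: a typed config with a documented default, where "1m" is parsed as a mebibyte count. A minimal, self-contained sketch of that byte-size parsing, assuming a simplified unit table rather than Spark's ByteUnit/JavaUtils helpers:

```scala
// Illustrative only: Spark parses "1m" via JavaUtils/ByteUnit; this standalone parser
// just shows how a default like END_EVENT_REPARSE_CHUNK_SIZE's "1m" becomes a byte count.
object ByteSizeSketch {
  private val units = Map("b" -> 1L, "k" -> 1024L, "m" -> 1024L * 1024, "g" -> 1024L * 1024 * 1024)

  def parseBytes(s: String): Long = {
    val (digits, suffix) = s.trim.toLowerCase.span(_.isDigit)
    val unit = if (suffix.isEmpty) "b" else suffix
    digits.toLong * units.getOrElse(unit, sys.error(s"unknown unit: $unit"))
  }

  def main(args: Array[String]): Unit = {
    println(parseBytes("1m"))   // 1048576
    println(parseBytes("0"))    // 0 disables the end-event reparse optimization
  }
}
```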
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index f699c75085fe1..fad4e46dc035d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -40,7 +40,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
.getOrElse(state.completedApps.find(_.id == appId).orNull)
if (app == null) {
val msg =
<div class="row-fluid">No running application with ID {appId}</div>
- return UIUtils.basicSparkPage(msg, "Not Found")
+ return UIUtils.basicSparkPage(request, msg, "Not Found")
}
val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs")
@@ -127,7 +127,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
}
;
- UIUtils.basicSparkPage(content, "Application: " + app.desc.name)
+ UIUtils.basicSparkPage(request, content, "Application: " + app.desc.name)
}
private def executorRow(executor: ExecutorDesc): Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index c629937606b51..b8afe203fbfa2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -215,7 +215,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
;
- UIUtils.basicSparkPage(content, "Spark Master at " + state.uri)
+ UIUtils.basicSparkPage(request, content, "Spark Master at " + state.uri)
}
private def workerRow(worker: WorkerInfo): Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 35b7ddd46e4db..e87b2240564bd 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -43,7 +43,7 @@ class MasterWebUI(
val masterPage = new MasterPage(this)
attachPage(new ApplicationPage(this))
attachPage(masterPage)
- attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static"))
+ addStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)
attachHandler(createRedirectHandler(
"/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST")))
attachHandler(createRedirectHandler(
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
index e88195d95f270..3d99d085408c6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala
@@ -94,6 +94,7 @@ private[spark] abstract class RestSubmissionServer(
new HttpConnectionFactory())
connector.setHost(host)
connector.setPort(startPort)
+ connector.setReuseAddress(!Utils.isWindows)
server.addConnector(connector)
val mainHandler = new ServletContextHandler
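The added `connector.setReuseAddress(!Utils.isWindows)` asks Jetty to set SO_REUSEADDR so a restarted server can rebind a port still in TIME_WAIT, except on Windows where the flag has different semantics. A plain-socket sketch of the same idea (names here are illustrative, not from the patch):

```scala
// Standalone illustration of SO_REUSEADDR: the flag must be set before bind(), just as the
// Jetty connector above is configured before the server starts.
import java.net.{InetSocketAddress, ServerSocket}

object ReuseAddressSketch {
  def main(args: Array[String]): Unit = {
    val isWindows = sys.props.getOrElse("os.name", "").toLowerCase.contains("windows")
    val socket = new ServerSocket()
    socket.setReuseAddress(!isWindows)
    socket.bind(new InetSocketAddress("localhost", 0))   // 0 = pick any free port
    println(s"bound port=${socket.getLocalPort} reuseAddress=${socket.getReuseAddress}")
    socket.close()
  }
}
```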
diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala
index 5151df00476f9..ab8d8d96a9b08 100644
--- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala
@@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging
*
* Also, each HadoopDelegationTokenProvider is controlled by
* spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to
- * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be
+ * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be
* enabled/disabled by the configuration spark.security.credentials.hive.enabled.
*
* @param sparkConf Spark configuration
@@ -52,7 +52,7 @@ private[spark] class HadoopDelegationTokenManager(
// Maintain all the registered delegation token providers
private val delegationTokenProviders = getDelegationTokenProviders
- logDebug(s"Using the following delegation token providers: " +
+ logDebug("Using the following builtin delegation token providers: " +
s"${delegationTokenProviders.keys.mkString(", ")}.")
/** Construct a [[HadoopDelegationTokenManager]] for the default Hadoop filesystem */
diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala
index ece5ce79c650d..7249eb85ac7c7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala
@@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.KEYTAB
import org.apache.spark.util.Utils
-private[security] class HiveDelegationTokenProvider
+private[spark] class HiveDelegationTokenProvider
extends HadoopDelegationTokenProvider with Logging {
override def serviceName: String = "hive"
@@ -124,9 +124,9 @@ private[security] class HiveDelegationTokenProvider
val currentUser = UserGroupInformation.getCurrentUser()
val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser)
- // For some reason the Scala-generated anonymous class ends up causing an
- // UndeclaredThrowableException, even if you annotate the method with @throws.
- try {
+ // For some reason the Scala-generated anonymous class ends up causing an
+ // UndeclaredThrowableException, even if you annotate the method with @throws.
+ try {
realUser.doAs(new PrivilegedExceptionAction[T]() {
override def run(): T = fn
})
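The comment above keeps an explicit `PrivilegedExceptionAction` instead of a Scala closure so checked exceptions do not surface as `UndeclaredThrowableException`. A self-contained sketch of the same wrapping, using `AccessController` in place of Hadoop's `UserGroupInformation` (which is not available here):

```scala
// Sketch only: UserGroupInformation.doAs is replaced by AccessController.doPrivileged to keep
// the example dependency-free; the anonymous PrivilegedExceptionAction is the point.
import java.security.{AccessController, PrivilegedExceptionAction}

object DoAsSketch {
  def runPrivileged[T](fn: => T): T =
    AccessController.doPrivileged(new PrivilegedExceptionAction[T] {
      override def run(): T = fn
    })

  def main(args: Array[String]): Unit = {
    println(runPrivileged { 21 * 2 })   // 42
  }
}
```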
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 58a181128eb4d..a6d13d12fc28d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -225,7 +225,7 @@ private[deploy] class DriverRunner(
// check if attempting another run
keepTrying = supervise && exitCode != 0 && !killed
if (keepTrying) {
- if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) {
+ if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000L) {
waitSeconds = 1
}
logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.")
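The only change in this hunk is the `L` suffix: `successfulRunDuration * 1000` multiplies two Ints and can overflow before being compared with a Long. A small standalone demonstration:

```scala
// Int * Int overflows silently; widening one operand to Long (the 1000L above) avoids it.
object OverflowSketch {
  def main(args: Array[String]): Unit = {
    val seconds = 3000000          // about 34.7 days, a plausible configured duration
    println(seconds * 1000)        // -1294967296  (Int overflow)
    println(seconds * 1000L)       // 3000000000   (correct, computed as Long)
  }
}
```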
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index b19c9904d5982..8d6a2b80ef5f2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -25,7 +25,7 @@ import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
+import org.apache.spark.util._
/**
* Utility object for launching driver programs such that they share fate with the Worker process.
@@ -79,16 +79,21 @@ object DriverWrapper extends Logging {
val secMgr = new SecurityManager(sparkConf)
val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf)
- val Seq(packagesExclusions, packages, repositories, ivyRepoPath) =
- Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy")
- .map(sys.props.get(_).orNull)
+ val Seq(packagesExclusions, packages, repositories, ivyRepoPath, ivySettingsPath) =
+ Seq(
+ "spark.jars.excludes",
+ "spark.jars.packages",
+ "spark.jars.repositories",
+ "spark.jars.ivy",
+ "spark.jars.ivySettings"
+ ).map(sys.props.get(_).orNull)
val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions,
- packages, repositories, ivyRepoPath)
+ packages, repositories, ivyRepoPath, Option(ivySettingsPath))
val jars = {
val jarsProp = sys.props.get("spark.jars").orNull
if (!StringUtils.isBlank(resolvedMavenCoordinates)) {
- SparkSubmit.mergeFileLists(jarsProp, resolvedMavenCoordinates)
+ DependencyUtils.mergeFileLists(jarsProp, resolvedMavenCoordinates)
} else {
jarsProp
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index d4d8521cc8204..dc6a3076a5113 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import com.google.common.io.Files
import org.apache.spark.{SecurityManager, SparkConf}
-import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
+import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
@@ -142,7 +142,11 @@ private[deploy] class ExecutorRunner(
private def fetchAndRunExecutor() {
try {
// Launch the process
- val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf),
+ val subsOpts = appDesc.command.javaOpts.map {
+ Utils.substituteAppNExecIds(_, appId, execId.toString)
+ }
+ val subsCommand = appDesc.command.copy(javaOpts = subsOpts)
+ val builder = CommandUtils.buildProcessBuilder(subsCommand, new SecurityManager(conf),
memory, sparkHome.getAbsolutePath, substituteVariables)
val command = builder.command()
val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"")
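The hunk above rewrites each executor Java option through `Utils.substituteAppNExecIds` before building the process. A hedged sketch of the substitution idea, assuming `{{APP_ID}}`/`{{EXECUTOR_ID}}` placeholder names (illustrative, not quoted from the patch):

```scala
// Assumed placeholder names for illustration; the real substitution lives in Utils.
object OptsSubstitutionSketch {
  def substituteAppNExecIds(opt: String, appId: String, execId: String): String =
    opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId)

  def main(args: Array[String]): Unit = {
    val javaOpts = Seq("-Dspark.log.dir=/var/log/{{APP_ID}}/{{EXECUTOR_ID}}")
    val substituted = javaOpts.map(substituteAppNExecIds(_, "app-20180401120000-0001", "3"))
    substituted.foreach(println)   // -Dspark.log.dir=/var/log/app-20180401120000-0001/3
  }
}
```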
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 563b84934f264..ee1ca0bba5749 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -23,6 +23,7 @@ import java.text.SimpleDateFormat
import java.util.{Date, Locale, UUID}
import java.util.concurrent._
import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture}
+import java.util.function.Supplier
import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap}
import scala.concurrent.ExecutionContext
@@ -49,7 +50,8 @@ private[deploy] class Worker(
endpointName: String,
workDirPath: String = null,
val conf: SparkConf,
- val securityMgr: SecurityManager)
+ val securityMgr: SecurityManager,
+ externalShuffleServiceSupplier: Supplier[ExternalShuffleService] = null)
extends ThreadSafeRpcEndpoint with Logging {
private val host = rpcEnv.address.host
@@ -97,6 +99,10 @@ private[deploy] class Worker(
private val APP_DATA_RETENTION_SECONDS =
conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600)
+ // Whether or not to clean up non-shuffle files when an executor exits.
+ private val CLEANUP_NON_SHUFFLE_FILES_ENABLED =
+ conf.getBoolean("spark.storage.cleanupFilesAfterExecutorExit", true)
+
private val testing: Boolean = sys.props.contains("spark.testing")
private var master: Option[RpcEndpointRef] = None
@@ -142,7 +148,11 @@ private[deploy] class Worker(
WorkerWebUI.DEFAULT_RETAINED_DRIVERS)
// The shuffle service is not actually started unless configured.
- private val shuffleService = new ExternalShuffleService(conf, securityMgr)
+ private val shuffleService = if (externalShuffleServiceSupplier != null) {
+ externalShuffleServiceSupplier.get()
+ } else {
+ new ExternalShuffleService(conf, securityMgr)
+ }
private val publicAddress = {
val envVar = conf.getenv("SPARK_PUBLIC_DNS")
@@ -732,6 +742,9 @@ private[deploy] class Worker(
trimFinishedExecutorsIfNecessary()
coresUsed -= executor.cores
memoryUsed -= executor.memory
+ if (CLEANUP_NON_SHUFFLE_FILES_ENABLED) {
+ shuffleService.executorRemoved(executorStateChanged.execId.toString, appId)
+ }
case None =>
logInfo("Unknown Executor " + fullId + " finished with state " + state +
message.map(" message " + _).getOrElse("") +
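Two ideas land in this file: an optional `Supplier[ExternalShuffleService]` so tests can inject a stub, and a config-gated `shuffleService.executorRemoved(...)` call when an executor finishes. A minimal, dependency-free sketch of the supplier-based injection (class names here are invented for the example):

```scala
// Hypothetical stand-ins: ShuffleServiceSketch/WorkerSketch are not Spark classes.
import java.util.function.Supplier

class ShuffleServiceSketch(val label: String)

class WorkerSketch(serviceSupplier: Supplier[ShuffleServiceSketch] = null) {
  // Same shape as the Worker change: use the injected instance if present, else build the default.
  val shuffleService: ShuffleServiceSketch =
    if (serviceSupplier != null) serviceSupplier.get() else new ShuffleServiceSketch("default")
}

object InjectionSketch {
  def main(args: Array[String]): Unit = {
    println(new WorkerSketch().shuffleService.label)   // default
    val mocked = new Supplier[ShuffleServiceSketch] {
      override def get(): ShuffleServiceSketch = new ShuffleServiceSketch("mock-for-test")
    }
    println(new WorkerSketch(mocked).shuffleService.label)   // mock-for-test
  }
}
```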
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
index 2f5a5642d3cab..4fca9342c0378 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
@@ -118,7 +118,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
- UIUtils.basicSparkPage(content, logType + " log page for " + pageName)
+ UIUtils.basicSparkPage(request, content, logType + " log page for " + pageName)
}
/** Get the part of the log files given the offset and desired length of bytes */
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
index 8b98ae56fc108..aa4e28d213e2b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala
@@ -135,7 +135,7 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
}
;
- UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format(
+ UIUtils.basicSparkPage(request, content, "Spark Worker at %s:%s".format(
workerState.host, workerState.port))
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index db696b04384bd..ea67b7434a769 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -47,7 +47,7 @@ class WorkerWebUI(
val logPage = new LogPage(this)
attachPage(logPage)
attachPage(new WorkerPage(this))
- attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static"))
+ addStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE)
attachHandler(createServletHandler("/log",
(request: HttpServletRequest) => logPage.renderLog(request),
worker.securityMgr,
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 9b62e4b1b7150..48d3630abd1f9 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -213,13 +213,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
driverConf.set(key, value)
}
}
- if (driverConf.contains("spark.yarn.credentials.file")) {
- logInfo("Will periodically update credentials from: " +
- driverConf.get("spark.yarn.credentials.file"))
- Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
- .getMethod("startCredentialUpdater", classOf[SparkConf])
- .invoke(null, driverConf)
- }
cfg.hadoopDelegationCreds.foreach { tokens =>
SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf)
@@ -234,11 +227,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
}
env.rpcEnv.awaitTermination()
- if (driverConf.contains("spark.yarn.credentials.file")) {
- Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil")
- .getMethod("stopCredentialUpdater")
- .invoke(null)
- }
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 2c3a8ef74800b..b1856ff0f3247 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -35,6 +35,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription}
@@ -141,8 +142,7 @@ private[spark] class Executor(
conf.getSizeAsBytes("spark.task.maxDirectResultSize", 1L << 20),
RpcUtils.maxMessageSizeBytes(conf))
- // Limit of bytes for total size of results (default is 1GB)
- private val maxResultSize = Utils.getMaxResultSize(conf)
+ private val maxResultSize = conf.get(MAX_RESULT_SIZE)
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
@@ -287,6 +287,28 @@ private[spark] class Executor(
notifyAll()
}
+ /**
+ * Utility function to:
+ * 1. Report executor runtime and JVM gc time if possible
+ * 2. Collect accumulator updates
+ * 3. Set the finished flag to true and clear current thread's interrupt status
+ */
+ private def collectAccumulatorsAndResetStatusOnFailure(taskStartTime: Long) = {
+ // Report executor runtime and JVM gc time
+ Option(task).foreach(t => {
+ t.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStartTime)
+ t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
+ })
+
+ // Collect latest accumulator values to report back to the driver
+ val accums: Seq[AccumulatorV2[_, _]] =
+ Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty)
+ val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
+
+ setTaskFinishedAndClearInterruptStatus()
+ (accums, accUpdates)
+ }
+
override def run(): Unit = {
threadId = Thread.currentThread.getId
Thread.currentThread.setName(threadName)
@@ -300,7 +322,7 @@ private[spark] class Executor(
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
- var taskStart: Long = 0
+ var taskStartTime: Long = 0
var taskStartCpu: Long = 0
startGCTime = computeTotalGcTime()
@@ -336,7 +358,7 @@ private[spark] class Executor(
}
// Run the actual task and measure its runtime.
- taskStart = System.currentTimeMillis()
+ taskStartTime = System.currentTimeMillis()
taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
@@ -396,11 +418,11 @@ private[spark] class Executor(
// Deserialization happens in two parts: first, we deserialize a Task object, which
// includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
task.metrics.setExecutorDeserializeTime(
- (taskStart - deserializeStartTime) + task.executorDeserializeTime)
+ (taskStartTime - deserializeStartTime) + task.executorDeserializeTime)
task.metrics.setExecutorDeserializeCpuTime(
(taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
// We need to subtract Task.run()'s deserialization time to avoid double-counting
- task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
+ task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime)
task.metrics.setExecutorCpuTime(
(taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
@@ -480,6 +502,22 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
+ case t: TaskKilledException =>
+ logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
+
+ val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
+ val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums))
+ execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+
+ case _: InterruptedException | NonFatal(_) if
+ task != null && task.reasonIfKilled.isDefined =>
+ val killReason = task.reasonIfKilled.getOrElse("unknown reason")
+ logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
+
+ val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
+ val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums))
+ execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+
case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
if (!t.isInstanceOf[FetchFailedException]) {
@@ -494,19 +532,6 @@ private[spark] class Executor(
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
- case t: TaskKilledException =>
- logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
- setTaskFinishedAndClearInterruptStatus()
- execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
-
- case _: InterruptedException | NonFatal(_) if
- task != null && task.reasonIfKilled.isDefined =>
- val killReason = task.reasonIfKilled.getOrElse("unknown reason")
- logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
- setTaskFinishedAndClearInterruptStatus()
- execBackend.statusUpdate(
- taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
-
case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskCommitDeniedReason
setTaskFinishedAndClearInterruptStatus()
@@ -524,17 +549,7 @@ private[spark] class Executor(
// the task failure would not be ignored if the shutdown happened because of preemption,
// instead of an app issue).
if (!ShutdownHookManager.inShutdown()) {
- // Collect latest accumulator values to report back to the driver
- val accums: Seq[AccumulatorV2[_, _]] =
- if (task != null) {
- task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
- task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
- task.collectAccumulatorUpdates(taskFailed = true)
- } else {
- Seq.empty
- }
-
- val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
+ val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
val serializedTaskEndReason = {
try {
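The reordered catch clauses above make killed tasks report their accumulator updates by routing them through the new `collectAccumulatorsAndResetStatusOnFailure` helper. A rough, self-contained stand-in for that helper (the `TaskSketch` type is invented for the example):

```scala
// Not Spark's types: a minimal model of "gather what we can from a possibly-null task, then
// clear the interrupt flag", which is what the helper above does before a KILLED/FAILED update.
object FailureCleanupSketch {
  final class TaskSketch(val runTimeMs: Long, val accumulators: Seq[Long])

  def collectAndReset(task: TaskSketch): (Seq[Long], Long) = {
    val accums  = Option(task).map(_.accumulators).getOrElse(Seq.empty)
    val runTime = Option(task).map(_.runTimeMs).getOrElse(0L)
    Thread.interrupted()   // clears (and discards) the current thread's interrupt status
    (accums, runTime)
  }

  def main(args: Array[String]): Unit = {
    println(collectAndReset(new TaskSketch(1234L, Seq(1L, 2L, 3L))))   // (List(1, 2, 3),1234)
    println(collectAndReset(null))                                     // (List(),0)
  }
}
```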
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
index b0cd7110a3b47..f27aca03773a9 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
@@ -23,6 +23,7 @@ import java.util.regex.PatternSyntaxException
import scala.util.matching.Regex
import org.apache.spark.network.util.{ByteUnit, JavaUtils}
+import org.apache.spark.util.Utils
private object ConfigHelpers {
@@ -45,7 +46,7 @@ private object ConfigHelpers {
}
def stringToSeq[T](str: String, converter: String => T): Seq[T] = {
- str.split(",").map(_.trim()).filter(_.nonEmpty).map(converter)
+ Utils.stringToSeq(str).map(converter)
}
def seqToString[T](v: Seq[T], stringConverter: T => String): String = {
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index bbfcfbaa7363c..a54b091a64d50 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -126,6 +126,10 @@ package object config {
private[spark] val DYN_ALLOCATION_MAX_EXECUTORS =
ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue)
+ private[spark] val DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO =
+ ConfigBuilder("spark.dynamicAllocation.executorAllocationRatio")
+ .doubleConf.createWithDefault(1.0)
+
private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("3s")
@@ -301,6 +305,12 @@ package object config {
.booleanConf
.createWithDefault(false)
+ private[spark] val IGNORE_MISSING_FILES = ConfigBuilder("spark.files.ignoreMissingFiles")
+ .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " +
+ "encountering missing files and the contents that have been read will still be returned.")
+ .booleanConf
+ .createWithDefault(false)
+
private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext")
.stringConf
.createOptional
@@ -332,7 +342,7 @@ package object config {
"a property key or value, the value is redacted from the environment UI and various logs " +
"like YARN and event logs.")
.regexConf
- .createWithDefault("(?i)secret|password|url|user|username".r)
+ .createWithDefault("(?i)secret|password".r)
private[spark] val STRING_REDACTION_PATTERN =
ConfigBuilder("spark.redaction.string.regex")
@@ -342,6 +352,11 @@ package object config {
.regexConf
.createOptional
+ private[spark] val AUTH_SECRET_BIT_LENGTH =
+ ConfigBuilder("spark.authenticate.secretBitLength")
+ .intConf
+ .createWithDefault(256)
+
private[spark] val NETWORK_AUTH_ENABLED =
ConfigBuilder("spark.authenticate")
.booleanConf
@@ -520,4 +535,21 @@ package object config {
.checkValue(v => v > 0, "The threshold should be positive.")
.createWithDefault(10000000)
+ private[spark] val MAX_RESULT_SIZE = ConfigBuilder("spark.driver.maxResultSize")
+ .doc("Size limit for results.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefaultString("1g")
+
+ private[spark] val CREDENTIALS_RENEWAL_INTERVAL_RATIO =
+ ConfigBuilder("spark.security.credentials.renewalRatio")
+ .doc("Ratio of the credential's expiration time when Spark should fetch new credentials.")
+ .doubleConf
+ .createWithDefault(0.75d)
+
+ private[spark] val CREDENTIALS_RENEWAL_RETRY_WAIT =
+ ConfigBuilder("spark.security.credentials.retryWait")
+ .doc("How long to wait before retrying to fetch new credentials after a failure.")
+ .timeConf(TimeUnit.SECONDS)
+ .createWithDefaultString("1h")
+
}
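Among the additions in this file, the redaction default is also narrowed from `(?i)secret|password|url|user|username` to `(?i)secret|password`, so only keys that look like credentials are masked. A small sketch of how such a key-based redaction regex is typically applied (the replacement marker text is an assumption, not quoted from this patch):

```scala
// Illustration of regex-keyed redaction over config key/value pairs.
object RedactionSketch {
  private val redactionPattern = "(?i)secret|password".r

  def redact(kv: Seq[(String, String)]): Seq[(String, String)] = kv.map {
    case (k, _) if redactionPattern.findFirstIn(k).isDefined => (k, "*********(redacted)")
    case other => other
  }

  def main(args: Array[String]): Unit = {
    redact(Seq(
      "spark.hadoop.fs.s3a.secret.key" -> "abc123",
      "spark.eventLog.dir"             -> "hdfs:///logs"   // no longer matched by the narrower default
    )).foreach(println)
  }
}
```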
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
index 6d0059b6a0272..e6e9c9e328853 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -20,6 +20,7 @@ package org.apache.spark.internal.io
import org.apache.hadoop.fs._
import org.apache.hadoop.mapreduce._
+import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
@@ -132,7 +133,7 @@ abstract class FileCommitProtocol {
}
-object FileCommitProtocol {
+object FileCommitProtocol extends Logging {
class TaskCommitMessage(val obj: Any) extends Serializable
object EmptyTaskCommitMessage extends TaskCommitMessage(null)
@@ -145,15 +146,23 @@ object FileCommitProtocol {
jobId: String,
outputPath: String,
dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = {
+
+ logDebug(s"Creating committer $className; job $jobId; output=$outputPath;" +
+ s" dynamic=$dynamicPartitionOverwrite")
val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]]
// First try the constructor with arguments (jobId: String, outputPath: String,
// dynamicPartitionOverwrite: Boolean).
// If that doesn't exist, try the one with (jobId: string, outputPath: String).
try {
val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean])
+ logDebug("Using (String, String, Boolean) constructor")
ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean])
} catch {
case _: NoSuchMethodException =>
+ logDebug("Falling back to (String, String) constructor")
+ require(!dynamicPartitionOverwrite,
+ "Dynamic Partition Overwrite is enabled but" +
+ s" the committer ${className} does not have the appropriate constructor")
val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String])
ctor.newInstance(jobId, outputPath)
}
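The instantiation logic above tries the three-argument constructor first and falls back to the two-argument one, refusing the fallback when dynamic partition overwrite was requested. A self-contained sketch of that reflection pattern (the `LegacyCommitterSketch` class is invented for the example):

```scala
// LegacyCommitterSketch stands in for a committer that only has the (jobId, path) constructor.
class LegacyCommitterSketch(jobId: String, outputPath: String) {
  override def toString: String = s"LegacyCommitterSketch($jobId, $outputPath)"
}

object CommitterReflectionSketch {
  def instantiate(className: String, jobId: String, path: String, dynamic: Boolean): AnyRef = {
    val clazz = Class.forName(className)
    try {
      // Prefer the newer (String, String, Boolean) constructor.
      clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean])
        .newInstance(jobId, path, java.lang.Boolean.valueOf(dynamic)).asInstanceOf[AnyRef]
    } catch {
      case _: NoSuchMethodException =>
        require(!dynamic, s"$className cannot support dynamic partition overwrite")
        clazz.getDeclaredConstructor(classOf[String], classOf[String])
          .newInstance(jobId, path).asInstanceOf[AnyRef]
    }
  }

  def main(args: Array[String]): Unit = {
    println(instantiate(classOf[LegacyCommitterSketch].getName, "job-1", "/tmp/out", dynamic = false))
  }
}
```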
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala
index ddbd624b380d4..af0aa41518766 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala
@@ -31,6 +31,8 @@ class HadoopMapRedCommitProtocol(jobId: String, path: String)
override def setupCommitter(context: NewTaskAttemptContext): OutputCommitter = {
val config = context.getConfiguration.asInstanceOf[JobConf]
- config.getOutputCommitter
+ val committer = config.getOutputCommitter
+ logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}")
+ committer
}
}
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index 6d20ef1f98a3c..3e60c50ada59b 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -186,7 +186,17 @@ class HadoopMapReduceCommitProtocol(
logDebug(s"Clean up default partition directories for overwriting: $partitionPaths")
for (part <- partitionPaths) {
val finalPartPath = new Path(path, part)
- fs.delete(finalPartPath, true)
+ if (!fs.delete(finalPartPath, true) && !fs.exists(finalPartPath.getParent)) {
+ // According to the official hadoop FileSystem API spec, delete op should assume
+ // the destination is no longer present regardless of return value, thus we do not
+ // need to double check if finalPartPath exists before rename.
+ // Also in our case, based on the spec, delete returns false only when finalPartPath
+ // does not exist. When this happens, we need to take action if parent of finalPartPath
+ // also does not exist (e.g. the scenario described in SPARK-23815), because the
+ // FileSystem API spec on rename op says the rename dest(finalPartPath) must have
+ // a parent that exists, otherwise we may get unexpected result on the rename.
+ fs.mkdirs(finalPartPath.getParent)
+ }
fs.rename(new Path(stagingDir, part), finalPartPath)
}
}
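The comment block above explains the invariant: a rename destination must have an existing parent, so if the delete left nothing behind the parent is recreated first (SPARK-23815). The same invariant expressed with `java.nio.file`, purely for illustration (Hadoop's `FileSystem` is not used here):

```scala
import java.nio.file.{Files, Paths}

object RenameParentSketch {
  def main(args: Array[String]): Unit = {
    val staging   = Files.createTempDirectory("staging")
    val srcPart   = Files.createDirectory(staging.resolve("part"))
    val finalPart = Paths.get(staging.toString, "output", "date=2018-04-01")

    Files.deleteIfExists(finalPart)                 // like fs.delete(...): may be a no-op
    if (!Files.exists(finalPart.getParent)) {
      Files.createDirectories(finalPart.getParent)  // mirrors fs.mkdirs(finalPartPath.getParent)
    }
    Files.move(srcPart, finalPart)                  // the rename now has a valid parent
    println(s"moved staging partition to $finalPart")
  }
}
```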
diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala
index aaae33ca4e6f3..1b049b786023a 100644
--- a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala
+++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala
@@ -67,13 +67,13 @@ private[spark] abstract class LauncherBackend {
}
def setAppId(appId: String): Unit = {
- if (connection != null) {
+ if (connection != null && isConnected) {
connection.send(new SetAppId(appId))
}
}
def setState(state: SparkAppHandle.State): Unit = {
- if (connection != null && lastState != state) {
+ if (connection != null && isConnected && lastState != state) {
connection.send(new SetState(state))
lastState = state
}
@@ -114,10 +114,10 @@ private[spark] abstract class LauncherBackend {
override def close(): Unit = {
try {
+ _isConnected = false
super.close()
} finally {
onDisconnected()
- _isConnected = false
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index c9ed12f4e1bd4..ba9dae4ad48ec 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -90,12 +90,12 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
// by 50%. We also cap the estimation in the end.
if (results.size == 0) {
- numPartsToTry = partsScanned * 4
+ numPartsToTry = partsScanned * 4L
} else {
// the left side of max is >=1 whenever partsScanned >= 2
numPartsToTry = Math.max(1,
(1.5 * num * partsScanned / results.size).toInt - partsScanned)
- numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
+ numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 10451a324b0f4..94e7d0b38cba3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -266,17 +266,17 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
numCreated += 1
}
}
- tries = 0
// if we don't have enough partition groups, create duplicates
while (numCreated < targetLen) {
- val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries)
- tries += 1
+ // Copy the preferred location from a random input partition.
+ // This helps in avoiding skew when the input partitions are clustered by preferred location.
+ val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(
+ rnd.nextInt(partitionLocs.partsWithLocs.length))
val pgroup = new PartitionGroup(Some(nxt_replica))
groupArr += pgroup
groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup
addPartToPGroup(nxt_part, pgroup)
numCreated += 1
- if (tries >= partitionLocs.partsWithLocs.length) tries = 0
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 2480559a41b7a..44895abc7bd4d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -17,7 +17,7 @@
package org.apache.spark.rdd
-import java.io.IOException
+import java.io.{FileNotFoundException, IOException}
import java.text.SimpleDateFormat
import java.util.{Date, Locale}
@@ -28,6 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapred._
import org.apache.hadoop.mapred.lib.CombineFileSplit
import org.apache.hadoop.mapreduce.TaskType
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark._
@@ -134,6 +135,8 @@ class HadoopRDD[K, V](
private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES)
+ private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES)
+
private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS)
// Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads.
@@ -197,17 +200,24 @@ class HadoopRDD[K, V](
val jobConf = getJobConf()
// add the credentials here as this can be called before SparkContext initialized
SparkHadoopUtil.get.addCredentials(jobConf)
- val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions)
- val inputSplits = if (ignoreEmptySplits) {
- allInputSplits.filter(_.getLength > 0)
- } else {
- allInputSplits
- }
- val array = new Array[Partition](inputSplits.size)
- for (i <- 0 until inputSplits.size) {
- array(i) = new HadoopPartition(id, i, inputSplits(i))
+ try {
+ val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions)
+ val inputSplits = if (ignoreEmptySplits) {
+ allInputSplits.filter(_.getLength > 0)
+ } else {
+ allInputSplits
+ }
+ val array = new Array[Partition](inputSplits.size)
+ for (i <- 0 until inputSplits.size) {
+ array(i) = new HadoopPartition(id, i, inputSplits(i))
+ }
+ array
+ } catch {
+ case e: InvalidInputException if ignoreMissingFiles =>
+ logWarning(s"${jobConf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" +
+ s" partitions returned from this path.", e)
+ Array.empty[Partition]
}
- array
}
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
@@ -256,6 +266,12 @@ class HadoopRDD[K, V](
try {
inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
} catch {
+ case e: FileNotFoundException if ignoreMissingFiles =>
+ logWarning(s"Skipped missing file: ${split.inputSplit}", e)
+ finished = true
+ null
+ // Throw FileNotFoundException even if `ignoreCorruptFiles` is true
+ case e: FileNotFoundException if !ignoreMissingFiles => throw e
case e: IOException if ignoreCorruptFiles =>
logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e)
finished = true
@@ -276,6 +292,11 @@ class HadoopRDD[K, V](
try {
finished = !reader.next(key, value)
} catch {
+ case e: FileNotFoundException if ignoreMissingFiles =>
+ logWarning(s"Skipped missing file: ${split.inputSplit}", e)
+ finished = true
+ // Throw FileNotFoundException even if `ignoreCorruptFiles` is true
+ case e: FileNotFoundException if !ignoreMissingFiles => throw e
case e: IOException if ignoreCorruptFiles =>
logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e)
finished = true
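Both hunks above apply the same rule: `FileNotFoundException` is swallowed with a warning only when `spark.files.ignoreMissingFiles` is set, and otherwise propagates even if corrupt-file tolerance is on. A standalone sketch of that guarded catch:

```scala
import java.io.FileNotFoundException
import scala.io.Source

object IgnoreMissingSketch {
  def readLines(path: String, ignoreMissingFiles: Boolean): Seq[String] =
    try {
      Source.fromFile(path).getLines().toList
    } catch {
      case e: FileNotFoundException if ignoreMissingFiles =>
        System.err.println(s"Skipped missing file: $path ($e)")
        Seq.empty
      // with the flag off, the exception propagates as before
    }

  def main(args: Array[String]): Unit = {
    println(readLines("/no/such/file.txt", ignoreMissingFiles = true))   // List()
  }
}
```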
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index e4dd1b6a82498..ff66a04859d10 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -17,7 +17,7 @@
package org.apache.spark.rdd
-import java.io.IOException
+import java.io.{FileNotFoundException, IOException}
import java.text.SimpleDateFormat
import java.util.{Date, Locale}
@@ -28,7 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit}
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileInputFormat, FileSplit, InvalidInputException}
import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl}
import org.apache.spark._
@@ -90,6 +90,8 @@ class NewHadoopRDD[K, V](
private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES)
+ private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES)
+
private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS)
def getConf: Configuration = {
@@ -124,17 +126,25 @@ class NewHadoopRDD[K, V](
configurable.setConf(_conf)
case _ =>
}
- val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala
- val rawSplits = if (ignoreEmptySplits) {
- allRowSplits.filter(_.getLength > 0)
- } else {
- allRowSplits
- }
- val result = new Array[Partition](rawSplits.size)
- for (i <- 0 until rawSplits.size) {
- result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ try {
+ val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala
+ val rawSplits = if (ignoreEmptySplits) {
+ allRowSplits.filter(_.getLength > 0)
+ } else {
+ allRowSplits
+ }
+ val result = new Array[Partition](rawSplits.size)
+ for (i <- 0 until rawSplits.size) {
+ result(i) =
+ new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ }
+ result
+ } catch {
+ case e: InvalidInputException if ignoreMissingFiles =>
+ logWarning(s"${_conf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" +
+ s" partitions returned from this path.", e)
+ Array.empty[Partition]
}
- result
}
override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
@@ -189,6 +199,12 @@ class NewHadoopRDD[K, V](
_reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
_reader
} catch {
+ case e: FileNotFoundException if ignoreMissingFiles =>
+ logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e)
+ finished = true
+ null
+ // Throw FileNotFoundException even if `ignoreCorruptFiles` is true
+ case e: FileNotFoundException if !ignoreMissingFiles => throw e
case e: IOException if ignoreCorruptFiles =>
logWarning(
s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}",
@@ -213,6 +229,11 @@ class NewHadoopRDD[K, V](
try {
finished = !reader.nextKeyValue
} catch {
+ case e: FileNotFoundException if ignoreMissingFiles =>
+ logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e)
+ finished = true
+ // Throw FileNotFoundException even if `ignoreCorruptFiles` is true
+ case e: FileNotFoundException if !ignoreMissingFiles => throw e
case e: IOException if ignoreCorruptFiles =>
logWarning(
s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}",
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala
index 7e14938acd8e0..e2b6df4600590 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala
@@ -34,7 +34,11 @@ import org.apache.spark.util.Utils
* Delivery will only begin when the `start()` method is called. The `stop()` method should be
* called when no more events need to be delivered.
*/
-private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics)
+private class AsyncEventQueue(
+ val name: String,
+ conf: SparkConf,
+ metrics: LiveListenerBusMetrics,
+ bus: LiveListenerBus)
extends SparkListenerBus
with Logging {
@@ -81,23 +85,18 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
}
private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) {
- try {
- var next: SparkListenerEvent = eventQueue.take()
- while (next != POISON_PILL) {
- val ctx = processingTime.time()
- try {
- super.postToAll(next)
- } finally {
- ctx.stop()
- }
- eventCount.decrementAndGet()
- next = eventQueue.take()
+ var next: SparkListenerEvent = eventQueue.take()
+ while (next != POISON_PILL) {
+ val ctx = processingTime.time()
+ try {
+ super.postToAll(next)
+ } finally {
+ ctx.stop()
}
eventCount.decrementAndGet()
- } catch {
- case ie: InterruptedException =>
- logInfo(s"Stopping listener queue $name.", ie)
+ next = eventQueue.take()
}
+ eventCount.decrementAndGet()
}
override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = {
@@ -130,7 +129,11 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
eventCount.incrementAndGet()
eventQueue.put(POISON_PILL)
}
- dispatchThread.join()
+ // this thread might be trying to stop itself as part of error handling -- we can't join
+ // in that case.
+ if (Thread.currentThread() != dispatchThread) {
+ dispatchThread.join()
+ }
}
def post(event: SparkListenerEvent): Unit = {
@@ -166,7 +169,7 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
val prevLastReportTimestamp = lastReportTimestamp
lastReportTimestamp = System.currentTimeMillis()
val previous = new java.util.Date(prevLastReportTimestamp)
- logWarning(s"Dropped $droppedEvents events from $name since $previous.")
+ logWarning(s"Dropped $droppedCount events from $name since $previous.")
}
}
}
@@ -187,6 +190,12 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
true
}
+ override def removeListenerOnError(listener: SparkListenerInterface): Unit = {
+ // the listener failed in an unrecoverable way; we want to remove it from the entire
+ // LiveListenerBus (potentially stopping a queue if it is empty)
+ bus.removeListener(listener)
+ }
+
}
private object AsyncEventQueue {
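The `stop()` change above guards against the dispatch thread joining itself, which would deadlock when a listener error triggers a stop from inside the dispatch loop. A tiny standalone illustration of that guard:

```scala
// Minimal model of the guard: only join the dispatcher when stop() runs on a different thread.
object SelfJoinSketch {
  @volatile private var dispatchThread: Thread = _

  private def stop(): Unit = {
    if (Thread.currentThread() != dispatchThread) {
      dispatchThread.join()
      println(s"${Thread.currentThread().getName}: joined dispatch thread")
    } else {
      println("stop() called from the dispatch thread itself; skipping join to avoid deadlock")
    }
  }

  def main(args: Array[String]): Unit = {
    dispatchThread = new Thread(new Runnable {
      override def run(): Unit = stop()   // simulates error handling stopping the queue in-thread
    }, "dispatch-sketch")
    dispatchThread.start()
    dispatchThread.join()
    stop()   // from main: the join returns immediately since the thread already finished
  }
}
```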
diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala
index cd8e61d6d0208..30cf75d43ee09 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala
@@ -152,7 +152,8 @@ private[scheduler] class BlacklistTracker (
case Some(a) =>
logInfo(s"Killing blacklisted executor id $exec " +
s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.")
- a.killExecutors(Seq(exec), true, true)
+ a.killExecutors(Seq(exec), adjustTargetNumExecutors = false, countFailures = false,
+ force = true)
case None =>
logWarning(s"Not attempting to kill blacklisted executor id $exec " +
s"since allocation client is not defined.")
@@ -209,7 +210,7 @@ private[scheduler] class BlacklistTracker (
updateNextExpiryTime()
killBlacklistedExecutor(exec)
- val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]())
+ val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(host, HashSet[String]())
blacklistedExecsOnNode += exec
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 199937b8c27af..041eade82d3ca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.storage._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -206,7 +206,7 @@ class DAGScheduler(
private val messageScheduler =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message")
- private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
+ private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)
/**
@@ -1016,15 +1016,24 @@ class DAGScheduler(
// might modify state of objects referenced in their closures. This is necessary in Hadoop
// where the JobConf/Configuration object is not thread-safe.
var taskBinary: Broadcast[Array[Byte]] = null
+ var partitions: Array[Partition] = null
try {
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
- val taskBinaryBytes: Array[Byte] = stage match {
- case stage: ShuffleMapStage =>
- JavaUtils.bufferToArray(
- closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
- case stage: ResultStage =>
- JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
+ var taskBinaryBytes: Array[Byte] = null
+ // taskBinaryBytes and partitions are both affected by the checkpoint status. We need
+ // this synchronization in case another concurrent job is checkpointing this RDD, so we get a
+ // consistent view of both variables.
+ RDDCheckpointData.synchronized {
+ taskBinaryBytes = stage match {
+ case stage: ShuffleMapStage =>
+ JavaUtils.bufferToArray(
+ closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
+ case stage: ResultStage =>
+ JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
+ }
+
+ partitions = stage.rdd.partitions
}
taskBinary = sc.broadcast(taskBinaryBytes)
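The new `RDDCheckpointData.synchronized` block makes the serialized task binary and the partitions array come from one consistent view of the RDD, even if a concurrent job checkpoints it. A dependency-free sketch of taking such a two-field snapshot under a shared lock (`CheckpointLockSketch` is illustrative, not Spark's lock object):

```scala
object CheckpointLockSketch {
  private val lock = new Object
  private var serializedPlan: Array[Byte] = Array[Byte](1, 2, 3)
  private var partitions: Array[Int]      = Array(0, 1, 2)

  // Writer: both fields change together under the lock (like a checkpoint rewriting the RDD).
  def checkpoint(): Unit = lock.synchronized {
    serializedPlan = Array[Byte](9)
    partitions     = Array(0)
  }

  // Reader: both fields are read under the same lock, so the pair is always consistent.
  def snapshot(): (Array[Byte], Array[Int]) = lock.synchronized {
    (serializedPlan, partitions)
  }

  def main(args: Array[String]): Unit = {
    val (plan, parts) = snapshot()
    println(s"before: plan=${plan.length} bytes, ${parts.length} partitions")
    checkpoint()
    val (plan2, parts2) = snapshot()
    println(s"after:  plan=${plan2.length} bytes, ${parts2.length} partitions")
  }
}
```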
@@ -1049,7 +1058,7 @@ class DAGScheduler(
stage.pendingPartitions.clear()
partitionsToCompute.map { id =>
val locs = taskIdToLocations(id)
- val part = stage.rdd.partitions(id)
+ val part = partitions(id)
stage.pendingPartitions += id
new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
@@ -1059,7 +1068,7 @@ class DAGScheduler(
case stage: ResultStage =>
partitionsToCompute.map { id =>
val p: Int = stage.partitions(id)
- val part = stage.rdd.partitions(p)
+ val part = partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary, part, locs, id, properties, serializedTaskMetrics,
@@ -1083,17 +1092,16 @@ class DAGScheduler(
// the stage as completed here in case there are no tasks to run
markStageAsFinished(stage, None)
- val debugString = stage match {
+ stage match {
case stage: ShuffleMapStage =>
- s"Stage ${stage} is actually done; " +
- s"(available: ${stage.isAvailable}," +
- s"available outputs: ${stage.numAvailableOutputs}," +
- s"partitions: ${stage.numPartitions})"
+ logDebug(s"Stage ${stage} is actually done; " +
+ s"(available: ${stage.isAvailable}," +
+ s"available outputs: ${stage.numAvailableOutputs}," +
+ s"partitions: ${stage.numPartitions})")
+ markMapStageJobsAsFinished(stage)
case stage : ResultStage =>
- s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
+ logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})")
}
- logDebug(debugString)
-
submitWaitingChildStages(stage)
}
}
@@ -1159,9 +1167,7 @@ class DAGScheduler(
*/
private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
- val taskId = event.taskInfo.id
val stageId = task.stageId
- val taskType = Utils.getFormattedClassName(task)
outputCommitCoordinator.taskCompleted(
stageId,
@@ -1202,7 +1208,7 @@ class DAGScheduler(
case _ =>
updateAccumulators(event)
}
- case _: ExceptionFailure => updateAccumulators(event)
+ case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event)
case _ =>
}
postTaskEnd(event)
@@ -1298,13 +1304,7 @@ class DAGScheduler(
shuffleStage.findMissingPartitions().mkString(", "))
submitStage(shuffleStage)
} else {
- // Mark any map-stage jobs waiting on this stage as finished
- if (shuffleStage.mapStageJobs.nonEmpty) {
- val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep)
- for (job <- shuffleStage.mapStageJobs) {
- markMapStageJobAsFinished(job, stats)
- }
- }
+ markMapStageJobsAsFinished(shuffleStage)
submitWaitingChildStages(shuffleStage)
}
}
@@ -1321,7 +1321,7 @@ class DAGScheduler(
"tasks in ShuffleMapStages.")
}
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleIdToMapStage(shuffleId)
@@ -1409,21 +1409,31 @@ class DAGScheduler(
}
}
- case commitDenied: TaskCommitDenied =>
+ case _: TaskCommitDenied =>
// Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
- case exceptionFailure: ExceptionFailure =>
+ case _: ExceptionFailure | _: TaskKilled =>
// Nothing left to do, already handled above for accumulator updates.
case TaskResultLost =>
// Do nothing here; the TaskScheduler handles these failures and resubmits the task.
- case _: ExecutorLostFailure | _: TaskKilled | UnknownReason =>
+ case _: ExecutorLostFailure | UnknownReason =>
// Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
// will abort the job.
}
}
+ private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = {
+ // Mark any map-stage jobs waiting on this stage as finished
+ if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) {
+ val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep)
+ for (job <- shuffleStage.mapStageJobs) {
+ markMapStageJobAsFinished(job, stats)
+ }
+ }
+ }
+
/**
* Responds to an executor being lost. This is called inside the event loop, so it assumes it can
* modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
index ba6387a8f08ad..d135190d1e919 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -102,7 +102,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) {
queue.addListener(listener)
case None =>
- val newQueue = new AsyncEventQueue(queue, conf, metrics)
+ val newQueue = new AsyncEventQueue(queue, conf, metrics, this)
newQueue.addListener(listener)
if (started.get()) {
newQueue.start(sparkContext)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
index c9cd662f5709d..226c23733c870 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
@@ -115,6 +115,8 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging {
}
}
} catch {
+ case e: HaltReplayException =>
+ // Just stop replay.
case _: EOFException if maybeTruncated =>
case ioe: IOException =>
throw ioe
@@ -124,8 +126,17 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging {
}
}
+ override protected def isIgnorableException(e: Throwable): Boolean = {
+ e.isInstanceOf[HaltReplayException]
+ }
+
}
+/**
+ * Exception that can be thrown by listeners to halt replay. This is handled by ReplayListenerBus
+ * only, and will cause errors if thrown when using other bus implementations.
+ */
+private[spark] class HaltReplayException extends RuntimeException
private[spark] object ReplayListenerBus {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 0c11806b3981b..598b62f85a1fa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -42,7 +42,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils}
* up to launch speculative tasks, etc.
*
* Clients should first call initialize() and start(), then submit task sets through the
- * runTasks method.
+ * submitTasks method.
*
* THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
@@ -62,7 +62,7 @@ private[spark] class TaskSchedulerImpl(
this(sc, sc.conf.get(config.MAX_TASK_FAILURES))
}
- // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient,
+ // Lazily initializing blacklistTrackerOpt to avoid getting empty ExecutorAllocationClient,
// because ExecutorAllocationClient is created after this TaskSchedulerImpl.
private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc)
@@ -228,7 +228,7 @@ private[spark] class TaskSchedulerImpl(
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
- // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
taskIdToExecutorId.get(tid).foreach(execId =>
@@ -689,6 +689,20 @@ private[spark] class TaskSchedulerImpl(
}
}
+ /**
+ * Marks the task as completed in all TaskSetManagers for the given stage.
+ *
+ * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
+ * If an earlier attempt of a stage completes a task, we should ensure that the later attempts
+ * do not also submit those same tasks. That also means that a task completion from an earlier
+ * attempt can lead to the entire stage getting marked as successful.
+ */
+ private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = {
+ taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
+ tsm.markPartitionCompleted(partitionId)
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 886c2c99f1ff3..a18c66596852a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -64,8 +64,7 @@ private[spark] class TaskSetManager(
val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5)
- // Limit of bytes for total size of results (default is 1GB)
- val maxResultSize = Utils.getMaxResultSize(conf)
+ val maxResultSize = conf.get(config.MAX_RESULT_SIZE)
val speculationEnabled = conf.getBoolean("spark.speculation", false)
@@ -74,6 +73,8 @@ private[spark] class TaskSetManager(
val ser = env.closureSerializer.newInstance()
val tasks = taskSet.tasks
+ private[scheduler] val partitionToIndex = tasks.zipWithIndex
+ .map { case (t, idx) => t.partitionId -> idx }.toMap
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
@@ -154,7 +155,7 @@ private[spark] class TaskSetManager(
private[scheduler] val speculatableTasks = new HashSet[Int]
// Task index, start and finish time for each task attempt (indexed by task ID)
- private val taskInfos = new HashMap[Long, TaskInfo]
+ private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]
// Use a MedianHeap to record durations of successful tasks so we know when to launch
// speculative tasks. This is only used when speculation is enabled, to avoid the overhead
@@ -288,7 +289,7 @@ private[spark] class TaskSetManager(
None
}
- /** Check whether a task is currently running an attempt on a given host */
+ /** Check whether a task once ran an attempt on a given host */
private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
taskAttempts(taskIndex).exists(_.host == host)
}
@@ -755,6 +756,9 @@ private[spark] class TaskSetManager(
logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id +
" because task " + index + " has already completed successfully")
}
+ // There may be multiple tasksets for this stage -- we let all of them know that the partition
+ // was completed. This may result in some of the tasksets getting completed.
+ sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId)
// This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
// "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
// "deserialize" the value when holding a lock to avoid blocking other threads. So we call
@@ -765,6 +769,19 @@ private[spark] class TaskSetManager(
maybeFinishTaskSet()
}
+ private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
+ partitionToIndex.get(partitionId).foreach { index =>
+ if (!successful(index)) {
+ tasksSuccessful += 1
+ successful(index) = true
+ if (tasksSuccessful == numTasks) {
+ isZombie = true
+ }
+ maybeFinishTaskSet()
+ }
+ }
+ }
+
/**
* Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
* DAG Scheduler.
@@ -834,13 +851,19 @@ private[spark] class TaskSetManager(
}
ef.exception
+ case tk: TaskKilled =>
+ // TaskKilled might have accumulator updates
+ accumUpdates = tk.accums
+ logWarning(failureReason)
+ None
+
case e: ExecutorLostFailure if !e.exitCausedByApp =>
logInfo(s"Task $tid failed because while it was being computed, its executor " +
"exited for a reason unrelated to the task. Not counting this failure towards the " +
"maximum number of failures for the task.")
None
- case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others
+ case e: TaskFailedReason => // TaskResultLost and others
logWarning(failureReason)
None
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 4d75063fbf1c5..d8794e8e551aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -147,7 +147,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
case KillExecutorsOnHost(host) =>
scheduler.getExecutorsAliveOnHost(host).foreach { exec =>
- killExecutors(exec.toSeq, replace = true, force = true)
+ killExecutors(exec.toSeq, adjustTargetNumExecutors = false, countFailures = false,
+ force = true)
}
case UpdateDelegationTokens(newDelegationTokens) =>
@@ -584,18 +585,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
/**
* Request that the cluster manager kill the specified executors.
*
- * When asking the executor to be replaced, the executor loss is considered a failure, and
- * killed tasks that are running on the executor will count towards the failure limits. If no
- * replacement is being requested, then the tasks will not count towards the limit.
- *
* @param executorIds identifiers of executors to kill
- * @param replace whether to replace the killed executors with new ones, default false
+ * @param adjustTargetNumExecutors whether the target number of executors should be adjusted down
+ * after these executors have been killed
+ * @param countFailures if there are tasks running on the executors when they are killed, whether
+ * those failures should count toward the task failure limits
* @param force whether to force kill busy executors, default false
* @return the ids of the executors acknowledged by the cluster manager to be removed.
*/
final override def killExecutors(
executorIds: Seq[String],
- replace: Boolean,
+ adjustTargetNumExecutors: Boolean,
+ countFailures: Boolean,
force: Boolean): Seq[String] = {
logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
@@ -610,7 +611,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
val executorsToKill = knownExecutors
.filter { id => !executorsPendingToRemove.contains(id) }
.filter { id => force || !scheduler.isExecutorBusy(id) }
- executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace }
+ executorsToKill.foreach { id => executorsPendingToRemove(id) = !countFailures }
logInfo(s"Actual list of executor(s) to be killed is ${executorsToKill.mkString(", ")}")
@@ -618,12 +619,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// with the cluster manager to avoid allocating new ones. When computing the new target,
// take into account executors that are pending to be added or removed.
val adjustTotalExecutors =
- if (!replace) {
+ if (adjustTargetNumExecutors) {
requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0)
if (requestedTotalExecutors !=
(numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) {
logDebug(
- s"""killExecutors($executorIds, $replace, $force): Executor counts do not match:
+ s"""killExecutors($executorIds, $adjustTargetNumExecutors, $countFailures, $force):
+ |Executor counts do not match:
|requestedTotalExecutors = $requestedTotalExecutors
|numExistingExecutors = $numExistingExecutors
|numPendingExecutors = $numPendingExecutors
@@ -631,7 +633,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
doRequestTotalExecutors(requestedTotalExecutors)
} else {
- numPendingExecutors += knownExecutors.size
+ numPendingExecutors += executorsToKill.size
Future.successful(true)
}
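As a hedged sketch of how the two new flags combine (the executor ids are placeholders and `backend` is assumed to be a running CoarseGrainedSchedulerBackend):

    // Downscaling: lower the target so the killed executors are not replaced, and count
    // failures of any tasks that were running on them.
    backend.killExecutors(Seq("1", "2"), adjustTargetNumExecutors = true,
      countFailures = true, force = false)

    // Kill-and-replace (the KillExecutorsOnHost case above): keep the target unchanged and
    // do not charge interrupted tasks against their failure limits.
    backend.killExecutors(Seq("3"), adjustTargetNumExecutors = false,
      countFailures = false, force = true)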
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
new file mode 100644
index 0000000000000..d15e7937b0523
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.security
+
+import java.io.{DataInputStream, DataOutputStream, InputStream}
+import java.net.Socket
+import java.nio.charset.StandardCharsets.UTF_8
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
+
+/**
+ * A class that can be used to add a simple authentication protocol to socket-based communication.
+ *
+ * The protocol is simple: an auth secret is written to the socket, and the other side checks the
+ * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is
+ * not expected to be valid anymore.
+ *
+ * There's no secrecy, so this relies on the sockets being either local or somehow encrypted.
+ */
+private[spark] class SocketAuthHelper(conf: SparkConf) {
+
+ val secret = Utils.createSecret(conf)
+
+ /**
+ * Read the auth secret from the socket and compare to the expected value. Write the reply back
+ * to the socket.
+ *
+ * If authentication fails, this method will close the socket.
+ *
+ * @param s The client socket.
+ * @throws IllegalArgumentException If authentication fails.
+ */
+ def authClient(s: Socket): Unit = {
+ // Set the socket timeout while checking the auth secret. Reset it before returning.
+ val currentTimeout = s.getSoTimeout()
+ try {
+ s.setSoTimeout(10000)
+ val clientSecret = readUtf8(s)
+ if (secret == clientSecret) {
+ writeUtf8("ok", s)
+ } else {
+ writeUtf8("err", s)
+ JavaUtils.closeQuietly(s)
+ }
+ } finally {
+ s.setSoTimeout(currentTimeout)
+ }
+ }
+
+ /**
+ * Authenticate with a server by writing the auth secret and checking the server's reply.
+ *
+ * If authentication fails, this method will close the socket.
+ *
+ * @param s The socket connected to the server.
+ * @throws IllegalArgumentException If authentication fails.
+ */
+ def authToServer(s: Socket): Unit = {
+ writeUtf8(secret, s)
+
+ val reply = readUtf8(s)
+ if (reply != "ok") {
+ JavaUtils.closeQuietly(s)
+ throw new IllegalArgumentException("Authentication failed.")
+ }
+ }
+
+ protected def readUtf8(s: Socket): String = {
+ val din = new DataInputStream(s.getInputStream())
+ val len = din.readInt()
+ val bytes = new Array[Byte](len)
+ din.readFully(bytes)
+ new String(bytes, UTF_8)
+ }
+
+ protected def writeUtf8(str: String, s: Socket): Unit = {
+ val bytes = str.getBytes(UTF_8)
+ val dout = new DataOutputStream(s.getOutputStream())
+ dout.writeInt(bytes.length)
+ dout.write(bytes, 0, bytes.length)
+ dout.flush()
+ }
+
+}
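A minimal sketch of the handshake described in the class comment, assuming both ends are built from the same SparkConf (and therefore share one secret); the object name is made up and this is not an existing test:

    package org.apache.spark.security

    import java.net.{ServerSocket, Socket}

    import org.apache.spark.SparkConf

    object SocketAuthHelperExample {
      def main(args: Array[String]): Unit = {
        val helper = new SocketAuthHelper(new SparkConf(false))
        val server = new ServerSocket(0)
        // Server side: read the client's secret and answer "ok" or "err".
        val serverSide = new Thread(new Runnable {
          override def run(): Unit = helper.authClient(server.accept())
        })
        serverSide.start()
        // Client side: send the secret and fail fast if the reply is not "ok".
        val client = new Socket("localhost", server.getLocalPort)
        helper.authToServer(client)
        serverSide.join()
        client.close()
        server.close()
      }
    }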
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 538ae05e4eea1..72427dd6ce4d4 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -206,6 +206,7 @@ class KryoSerializer(conf: SparkConf)
kryo.register(clazz)
} catch {
case NonFatal(_) => // do nothing
+ case _: NoClassDefFoundError if Utils.isTesting => // See SPARK-23422.
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 0562d45ff57c5..4103dfb10175e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -90,12 +90,11 @@ private[spark] class BlockStoreShuffleReader[K, C](
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
- require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// Sort the output if there is a sort ordering defined.
- dep.keyOrdering match {
+ val resultIter = dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data.
val sorter =
@@ -104,9 +103,21 @@ private[spark] class BlockStoreShuffleReader[K, C](
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+ // Use completion callback to stop sorter if task was finished/cancelled.
+ context.addTaskCompletionListener(_ => {
+ sorter.stop()
+ })
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
+
+ resultIter match {
+ case _: InterruptibleIterator[Product2[K, C]] => resultIter
+ case _ =>
+ // Use another interruptible iterator here to support task cancellation as aggregator
+ // or(and) sorter may have consumed previous interruptible iterator.
+ new InterruptibleIterator[Product2[K, C]](context, resultIter)
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index c5f3f6e2b42b6..d3f1c7ec1bbee 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -84,7 +84,7 @@ private[spark] class IndexShuffleBlockResolver(
*/
private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = {
// the index file should have `block + 1` longs as offset.
- if (index.length() != (blocks + 1) * 8) {
+ if (index.length() != (blocks + 1) * 8L) {
return null
}
val lengths = new Array[Long](blocks)
@@ -202,13 +202,13 @@ private[spark] class IndexShuffleBlockResolver(
// class of issue from re-occurring in the future which is why they are left here even though
// SPARK-22982 is fixed.
val channel = Files.newByteChannel(indexFile.toPath)
- channel.position(blockId.reduceId * 8)
+ channel.position(blockId.reduceId * 8L)
val in = new DataInputStream(Channels.newInputStream(channel))
try {
val offset = in.readLong()
val nextOffset = in.readLong()
val actualPosition = channel.position()
- val expectedPosition = blockId.reduceId * 8 + 16
+ val expectedPosition = blockId.reduceId * 8L + 16
if (actualPosition != expectedPosition) {
throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
s"expected $expectedPosition but actual position was $actualPosition.")
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index bfb4dc698e325..d9fad64f34c7c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -188,9 +188,9 @@ private[spark] object SortShuffleManager extends Logging {
log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
s"${dependency.serializer.getClass.getName}, does not support object relocation")
false
- } else if (dependency.aggregator.isDefined) {
- log.debug(
- s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined")
+ } else if (dependency.mapSideCombine) {
+ log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " +
+ s"map-side aggregation")
false
} else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 636b88e792bf3..274399b9cc1f3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -50,7 +50,6 @@ private[spark] class SortShuffleWriter[K, V, C](
/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
- require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
@@ -107,7 +106,6 @@ private[spark] object SortShuffleWriter {
def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
// We cannot bypass sorting if we need to do map-side aggregation.
if (dep.mapSideCombine) {
- require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
false
} else {
val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
index ab01cddfca5b0..5ea161cd0d151 100644
--- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
@@ -213,11 +213,13 @@ private[spark] class AppStatusListener(
override def onExecutorBlacklistedForStage(
event: SparkListenerExecutorBlacklistedForStage): Unit = {
+ val now = System.nanoTime()
+
Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage =>
- val now = System.nanoTime()
- val esummary = stage.executorSummary(event.executorId)
- esummary.isBlacklisted = true
- maybeUpdate(esummary, now)
+ setStageBlackListStatus(stage, now, event.executorId)
+ }
+ liveExecutors.get(event.executorId).foreach { exec =>
+ addBlackListedStageTo(exec, event.stageId, now)
}
}
@@ -226,16 +228,29 @@ private[spark] class AppStatusListener(
// Implicitly blacklist every available executor for the stage associated with this node
Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage =>
- liveExecutors.values.foreach { exec =>
- if (exec.hostname == event.hostId) {
- val esummary = stage.executorSummary(exec.executorId)
- esummary.isBlacklisted = true
- maybeUpdate(esummary, now)
- }
- }
+ val executorIds = liveExecutors.values.filter(_.host == event.hostId).map(_.executorId).toSeq
+ setStageBlackListStatus(stage, now, executorIds: _*)
+ }
+ liveExecutors.values.filter(_.hostname == event.hostId).foreach { exec =>
+ addBlackListedStageTo(exec, event.stageId, now)
}
}
+ private def addBlackListedStageTo(exec: LiveExecutor, stageId: Int, now: Long): Unit = {
+ exec.blacklistedInStages += stageId
+ liveUpdate(exec, now)
+ }
+
+ private def setStageBlackListStatus(stage: LiveStage, now: Long, executorIds: String*): Unit = {
+ executorIds.foreach { executorId =>
+ val executorStageSummary = stage.executorSummary(executorId)
+ executorStageSummary.isBlacklisted = true
+ maybeUpdate(executorStageSummary, now)
+ }
+ stage.blackListedExecutors ++= executorIds
+ maybeUpdate(stage, now)
+ }
+
override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = {
updateBlackListStatus(event.executorId, false)
}
@@ -594,12 +609,24 @@ private[spark] class AppStatusListener(
stage.executorSummaries.values.foreach(update(_, now))
update(stage, now, last = true)
+
+ val executorIdsForStage = stage.blackListedExecutors
+ executorIdsForStage.foreach { executorId =>
+ liveExecutors.get(executorId).foreach { exec =>
+ removeBlackListedStageFrom(exec, event.stageInfo.stageId, now)
+ }
+ }
}
appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1)
kvstore.write(appSummary)
}
+ private def removeBlackListedStageFrom(exec: LiveExecutor, stageId: Int, now: Long) = {
+ exec.blacklistedInStages -= stageId
+ liveUpdate(exec, now)
+ }
+
override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = {
// This needs to set fields that are already set by onExecutorAdded because the driver is
// considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event.
@@ -888,7 +915,10 @@ private[spark] class AppStatusListener(
return
}
- val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime").first(0L)
+ // As the completion time of a skipped stage is always -1, we will remove skipped stages first.
+ // This is safe since the job itself contains enough information to render skipped stages in the
+ // UI.
+ val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime")
val stages = KVUtils.viewToSeq(view, countToDelete.toInt) { s =>
s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING
}
diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala
index efc28538a33db..688f25a9fdea1 100644
--- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala
+++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala
@@ -95,7 +95,11 @@ private[spark] class AppStatusStore(
}
def lastStageAttempt(stageId: Int): v1.StageData = {
- val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId)
+ val it = store.view(classOf[StageDataWrapper])
+ .index("stageId")
+ .reverse()
+ .first(stageId)
+ .last(stageId)
.closeableIterator()
try {
if (it.hasNext()) {
diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
index d5f9e19ffdcd0..79e3f13b826ce 100644
--- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
+++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
@@ -20,6 +20,7 @@ package org.apache.spark.status
import java.util.Date
import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.immutable.{HashSet, TreeSet}
import scala.collection.mutable.HashMap
import com.google.common.collect.Interners
@@ -254,6 +255,7 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE
var totalShuffleRead = 0L
var totalShuffleWrite = 0L
var isBlacklisted = false
+ var blacklistedInStages: Set[Int] = TreeSet()
var executorLogs = Map[String, String]()
@@ -299,7 +301,8 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE
Option(removeTime),
Option(removeReason),
executorLogs,
- memoryMetrics)
+ memoryMetrics,
+ blacklistedInStages)
new ExecutorSummaryWrapper(info)
}
@@ -371,6 +374,8 @@ private class LiveStage extends LiveEntity {
val executorSummaries = new HashMap[String, LiveExecutorStageSummary]()
+ var blackListedExecutors = new HashSet[String]()
+
// Used for cleanup of tasks after they reach the configured limit. Not written to the store.
@volatile var cleaning = false
var savedTasks = new AtomicInteger(0)
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
index ed9bdc6e1e3c2..d121068718b8a 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
@@ -49,6 +49,7 @@ private[v1] class ApiRootResource extends ApiRequestContext {
@Path("applications/{appId}")
def application(): Class[OneApplicationResource] = classOf[OneApplicationResource]
+ @GET
@Path("version")
def version(): VersionInfo = new VersionInfo(org.apache.spark.SPARK_VERSION)
@@ -157,6 +158,14 @@ private[v1] class NotFoundException(msg: String) extends WebApplicationException
.build()
)
+private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException(
+ new ServiceUnavailableException(msg),
+ Response
+ .status(Response.Status.SERVICE_UNAVAILABLE)
+ .entity(ErrorWrapper(msg))
+ .build()
+)
+
private[v1] class BadParameterException(msg: String) extends WebApplicationException(
new IllegalArgumentException(msg),
Response
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala
index bd4df07e7afc6..974697890dd03 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala
@@ -19,13 +19,13 @@ package org.apache.spark.status.api.v1
import java.io.OutputStream
import java.util.{List => JList}
import java.util.zip.ZipOutputStream
-import javax.ws.rs.{GET, Path, PathParam, Produces, QueryParam}
+import javax.ws.rs._
import javax.ws.rs.core.{MediaType, Response, StreamingOutput}
import scala.util.control.NonFatal
-import org.apache.spark.JobExecutionStatus
-import org.apache.spark.ui.SparkUI
+import org.apache.spark.{JobExecutionStatus, SparkContext}
+import org.apache.spark.ui.UIUtils
@Produces(Array(MediaType.APPLICATION_JSON))
private[v1] class AbstractApplicationResource extends BaseAppResource {
@@ -51,6 +51,29 @@ private[v1] class AbstractApplicationResource extends BaseAppResource {
@Path("executors")
def executorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(true))
+ @GET
+ @Path("executors/{executorId}/threads")
+ def threadDump(@PathParam("executorId") execId: String): Array[ThreadStackTrace] = withUI { ui =>
+ if (execId != SparkContext.DRIVER_IDENTIFIER && !execId.forall(Character.isDigit)) {
+ throw new BadParameterException(
+ s"Invalid executorId: neither '${SparkContext.DRIVER_IDENTIFIER}' nor number.")
+ }
+
+ val safeSparkContext = ui.sc.getOrElse {
+ throw new ServiceUnavailable("Thread dumps not available through the history server.")
+ }
+
+ ui.store.asOption(ui.store.executorSummary(execId)) match {
+ case Some(executorSummary) if executorSummary.isActive =>
+ val safeThreadDump = safeSparkContext.getExecutorThreadDump(execId).getOrElse {
+ throw new NotFoundException("No thread dump is available.")
+ }
+ safeThreadDump
+ case Some(_) => throw new BadParameterException("Executor is not active.")
+ case _ => throw new NotFoundException("Executor does not exist.")
+ }
+ }
+
@GET
@Path("allexecutors")
def allExecutorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(false))
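A hypothetical client call against the new endpoint (host, port, application id and executor id are placeholders). It only works against a live application's UI; the history server answers with the ServiceUnavailable error above:

    import scala.io.Source

    val threads = Source.fromURL(
      "http://localhost:4040/api/v1/applications/app-20180101000000-0000/executors/1/threads")
    println(threads.mkString)   // JSON array of ThreadStackTrace objects
    threads.close()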
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
index 550eac3952bbb..971d7e90fa7b8 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
@@ -19,6 +19,8 @@ package org.apache.spark.status.api.v1
import java.lang.{Long => JLong}
import java.util.Date
+import scala.xml.{NodeSeq, Text}
+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
@@ -95,7 +97,8 @@ class ExecutorSummary private[spark](
val removeTime: Option[Date],
val removeReason: Option[String],
val executorLogs: Map[String, String],
- val memoryMetrics: Option[MemoryMetrics])
+ val memoryMetrics: Option[MemoryMetrics],
+ val blacklistedInStages: Set[Int])
class MemoryMetrics private[spark](
val usedOnHeapStorageMemory: Long,
@@ -315,3 +318,32 @@ class RuntimeInfo private[spark](
val javaVersion: String,
val javaHome: String,
val scalaVersion: String)
+
+case class StackTrace(elems: Seq[String]) {
+ override def toString: String = elems.mkString
+
+ def html: NodeSeq = {
+ val withNewLine = elems.foldLeft(NodeSeq.Empty) { (acc, elem) =>
+ if (acc.isEmpty) {
+ acc :+ Text(elem)
+ } else {
+ acc :+ <br /> :+ Text(elem)
+ }
+ }
+
+ withNewLine
+ }
+
+ def mkString(start: String, sep: String, end: String): String = {
+ elems.mkString(start, sep, end)
+ }
+}
+
+case class ThreadStackTrace(
+ val threadId: Long,
+ val threadName: String,
+ val threadState: Thread.State,
+ val stackTrace: StackTrace,
+ val blockedByThreadId: Option[Long],
+ val blockedByLock: String,
+ val holdingLocks: Seq[String])
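A small usage sketch for the new case classes (the frame strings are made up):

    val trace = StackTrace(Seq(
      "java.lang.Object.wait(Native Method)\n",
      "java.lang.Thread.run(Thread.java:748)\n"))
    trace.toString   // the frames concatenated as plain text
    trace.html       // NodeSeq with a <br /> between frames, for rendering in the UI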
diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala
index 412644d3657b5..646cf25880e37 100644
--- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala
+++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala
@@ -109,6 +109,7 @@ private[spark] object TaskIndexNames {
final val DURATION = "dur"
final val ERROR = "err"
final val EXECUTOR = "exe"
+ final val HOST = "hst"
final val EXEC_CPU_TIME = "ect"
final val EXEC_RUN_TIME = "ert"
final val GC_TIME = "gc"
@@ -165,6 +166,7 @@ private[spark] class TaskDataWrapper(
val duration: Long,
@KVIndexParam(value = TaskIndexNames.EXECUTOR, parent = TaskIndexNames.STAGE)
val executorId: String,
+ @KVIndexParam(value = TaskIndexNames.HOST, parent = TaskIndexNames.STAGE)
val host: String,
@KVIndexParam(value = TaskIndexNames.STATUS, parent = TaskIndexNames.STAGE)
val status: String,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index e0276a4dc4224..df1a4bef616b2 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -291,7 +291,7 @@ private[spark] class BlockManager(
case e: Exception if i < MAX_ATTEMPTS =>
logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}"
+ s" more times after waiting $SLEEP_TIME_SECS seconds...", e)
- Thread.sleep(SLEEP_TIME_SECS * 1000)
+ Thread.sleep(SLEEP_TIME_SECS * 1000L)
case NonFatal(e) =>
throw new SparkException("Unable to register with external shuffle server due to : " +
e.getMessage, e)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 2c3da0ee85e06..d4a59c33b974c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -18,7 +18,8 @@
package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.util.concurrent.ConcurrentHashMap
+
+import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
@@ -132,10 +133,17 @@ private[spark] object BlockManagerId {
getCachedBlockManagerId(obj)
}
- val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
+ /**
+ * The max cache size is hardcoded to 10000. Since the size of a BlockManagerId
+ * object is about 48B, the total memory cost should be below 1MB, which is feasible.
+ */
+ val blockManagerIdCache = CacheBuilder.newBuilder()
+ .maximumSize(10000)
+ .build(new CacheLoader[BlockManagerId, BlockManagerId]() {
+ override def load(id: BlockManagerId) = id
+ })
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
- blockManagerIdCache.putIfAbsent(id, id)
blockManagerIdCache.get(id)
}
}
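The same bounded-interning pattern, shown on an arbitrary key type as a hedged sketch (the names are illustrative, not part of the Spark API):

    import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

    object InternExample {
      // load() just returns the key, so get(k) interns k up to the size bound.
      private val cache: LoadingCache[String, String] = CacheBuilder.newBuilder()
        .maximumSize(10000)
        .build(new CacheLoader[String, String]() {
          override def load(k: String): String = k
        })

      def intern(s: String): String = cache.get(s)
    }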
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 89a6a71a589a1..8e8f7d197c9ef 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -164,7 +164,8 @@ class BlockManagerMasterEndpoint(
val futures = blockManagerInfo.values.map { bm =>
bm.slaveEndpoint.ask[Int](removeMsg).recover {
case e: IOException =>
- logWarning(s"Error trying to remove RDD $rddId", e)
+ logWarning(s"Error trying to remove RDD $rddId from block manager ${bm.blockManagerId}",
+ e)
0 // zero blocks were removed
}
}.toSeq
@@ -192,11 +193,16 @@ class BlockManagerMasterEndpoint(
val requiredBlockManagers = blockManagerInfo.values.filter { info =>
removeFromDriver || !info.blockManagerId.isDriver
}
- Future.sequence(
- requiredBlockManagers.map { bm =>
- bm.slaveEndpoint.ask[Int](removeMsg)
- }.toSeq
- )
+ val futures = requiredBlockManagers.map { bm =>
+ bm.slaveEndpoint.ask[Int](removeMsg).recover {
+ case e: IOException =>
+ logWarning(s"Error trying to remove broadcast $broadcastId from block manager " +
+ s"${bm.blockManagerId}", e)
+ 0 // zero blocks were removed
+ }
+ }.toSeq
+
+ Future.sequence(futures)
}
private def removeBlockManager(blockManagerId: BlockManagerId) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala
index 353eac60df171..0bacc34cdfd90 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala
@@ -54,10 +54,9 @@ trait BlockReplicationPolicy {
}
object BlockReplicationUtils {
- // scalastyle:off line.size.limit
/**
* Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while
- * minimizing space usage. Please see
+ * minimizing space usage. Please see
* here.
*
* @param n total number of indices
@@ -65,7 +64,6 @@ object BlockReplicationUtils {
* @param r random number generator
* @return list of m random unique indices
*/
- // scalastyle:on line.size.limit
private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = {
val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) =>
val t = r.nextInt(i) + 1
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 98b5a735a4529..b31862323a895 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -48,7 +48,9 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
* @param blockManager [[BlockManager]] for reading local blocks
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
- * order to throttle the memory usage.
+ * order to throttle the memory usage. Note that zero-sized blocks are
+ * already excluded, which is done in
+ * [[MapOutputTracker.convertMapStatuses]].
* @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
* @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
@@ -62,7 +64,7 @@ final class ShuffleBlockFetcherIterator(
context: TaskContext,
shuffleClient: ShuffleClient,
blockManager: BlockManager,
- blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int,
@@ -74,8 +76,8 @@ final class ShuffleBlockFetcherIterator(
import ShuffleBlockFetcherIterator._
/**
- * Total number of blocks to fetch. This can be smaller than the total number of blocks
- * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]].
+ * Total number of blocks to fetch. This should be equal to the total number of blocks
+ * in [[blocksByAddress]] because zero-sized blocks are already filtered out there.
*
* This should equal localBlocks.size + remoteBlocks.size.
*/
@@ -90,7 +92,7 @@ final class ShuffleBlockFetcherIterator(
private[this] val startTime = System.currentTimeMillis
/** Local blocks to fetch, excluding zero-sized blocks. */
- private[this] val localBlocks = new ArrayBuffer[BlockId]()
+ private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]()
/** Remote blocks to fetch, excluding zero-sized blocks. */
private[this] val remoteBlocks = new HashSet[BlockId]()
@@ -267,13 +269,16 @@ final class ShuffleBlockFetcherIterator(
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
- // Tracks total number of blocks (including zero sized blocks)
- var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
- totalBlocks += blockInfos.size
if (address.executorId == blockManager.blockManagerId.executorId) {
- // Filter out zero-sized blocks
- localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
+ blockInfos.find(_._2 <= 0) match {
+ case Some((blockId, size)) if size < 0 =>
+ throw new BlockException(blockId, "Negative block size " + size)
+ case Some((blockId, size)) if size == 0 =>
+ throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
+ case None => // do nothing.
+ }
+ localBlocks ++= blockInfos.map(_._1)
numBlocksToFetch += localBlocks.size
} else {
val iterator = blockInfos.iterator
@@ -281,14 +286,15 @@ final class ShuffleBlockFetcherIterator(
var curBlocks = new ArrayBuffer[(BlockId, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
- // Skip empty blocks
- if (size > 0) {
+ if (size < 0) {
+ throw new BlockException(blockId, "Negative block size " + size)
+ } else if (size == 0) {
+ throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
+ } else {
curBlocks += ((blockId, size))
remoteBlocks += blockId
numBlocksToFetch += 1
curRequestSize += size
- } else if (size < 0) {
- throw new BlockException(blockId, "Negative block size " + size)
}
if (curRequestSize >= targetRequestSize ||
curBlocks.size >= maxBlocksInFlightPerAddress) {
@@ -306,7 +312,8 @@ final class ShuffleBlockFetcherIterator(
}
}
}
- logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
+ logInfo(s"Getting $numBlocksToFetch non-empty blocks including ${localBlocks.size}" +
+ s" local blocks and ${remoteBlocks.size} remote blocks")
remoteRequests
}
@@ -316,6 +323,7 @@ final class ShuffleBlockFetcherIterator(
* track in-memory are the ManagedBuffer references themselves.
*/
private[this] def fetchLocalBlocks() {
+ logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
val iter = localBlocks.iterator
while (iter.hasNext) {
val blockId = iter.next()
@@ -324,7 +332,8 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
- results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
+ results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId,
+ buf.size(), buf, false))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
@@ -397,12 +406,33 @@ final class ShuffleBlockFetcherIterator(
}
shuffleMetrics.incRemoteBlocksFetched(1)
}
- bytesInFlight -= size
+ if (!localBlocks.contains(blockId)) {
+ bytesInFlight -= size
+ }
if (isNetworkReqDone) {
reqsInFlight -= 1
logDebug("Number of requests in flight " + reqsInFlight)
}
+ if (buf.size == 0) {
+ // We will never legitimately receive a zero-size block. All blocks with zero records
+ // have zero size and all zero-size blocks have no records (and hence should never
+ // have been requested in the first place). This statement relies on behaviors of the
+ // shuffle writers, which are guaranteed by the following test cases:
+ //
+ // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions"
+ // - UnsafeShuffleWriterSuite: "writeEmptyIterator"
+ // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing"
+ //
+ // There is no explicit test for SortShuffleWriter, but the underlying APIs that it
+ // uses are shared with the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter,
+ // which returns a zero size from commitAndGet() in case no records were written
+ // since the last call).
+ val msg = s"Received a zero-size buffer for block $blockId from $address " +
+ s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
+ throwFetchFailedException(blockId, address, new IOException(msg))
+ }
+
val in = try {
buf.createInputStream()
} catch {
@@ -583,8 +613,8 @@ object ShuffleBlockFetcherIterator {
* Result of a fetch from a remote block successfully.
* @param blockId block id
* @param address BlockManager that the block was fetched from.
- * @param size estimated size of the block, used to calculate bytesInFlight.
- * Note that this is NOT the exact bytes.
+ * @param size estimated size of the block. Note that this is NOT the exact bytes.
+ * The size of a remote block is used to calculate bytesInFlight.
* @param buf `ManagedBuffer` for the content.
* @param isNetworkReqDone Is this the last network request for this host in this fetch request.
*/
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index e9694fdbca2de..adc406bb1c441 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -24,19 +24,15 @@ import scala.collection.mutable
import sun.nio.ch.DirectBuffer
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
/**
- * :: DeveloperApi ::
* Storage information for each BlockManager.
*
* This class assumes BlockId and BlockStatus are immutable, such that the consumers of this
* class cannot mutate the source of the information. Accesses are not thread-safe.
*/
-@DeveloperApi
-@deprecated("This class may be removed or made private in a future release.", "2.2.0")
-class StorageStatus(
+private[spark] class StorageStatus(
val blockManagerId: BlockManagerId,
val maxMemory: Long,
val maxOnHeapMem: Option[Long],
@@ -44,9 +40,6 @@ class StorageStatus(
/**
* Internal representation of the blocks stored in this block manager.
- *
- * We store RDD blocks and non-RDD blocks separately to allow quick retrievals of RDD blocks.
- * These collections should only be mutated through the add/update/removeBlock methods.
*/
private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]]
private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus]
@@ -87,9 +80,6 @@ class StorageStatus(
*/
def rddBlocks: Map[BlockId, BlockStatus] = _rddBlocks.flatMap { case (_, blocks) => blocks }
- /** Return the blocks that belong to the given RDD stored in this block manager. */
- def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = _rddBlocks.getOrElse(rddId, Map.empty)
-
/** Add the given block to this storage status. If it already exists, overwrite it. */
private[spark] def addBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = {
updateStorageInfo(blockId, blockStatus)
@@ -101,46 +91,6 @@ class StorageStatus(
}
}
- /** Update the given block in this storage status. If it doesn't already exist, add it. */
- private[spark] def updateBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = {
- addBlock(blockId, blockStatus)
- }
-
- /** Remove the given block from this storage status. */
- private[spark] def removeBlock(blockId: BlockId): Option[BlockStatus] = {
- updateStorageInfo(blockId, BlockStatus.empty)
- blockId match {
- case RDDBlockId(rddId, _) =>
- // Actually remove the block, if it exists
- if (_rddBlocks.contains(rddId)) {
- val removed = _rddBlocks(rddId).remove(blockId)
- // If the given RDD has no more blocks left, remove the RDD
- if (_rddBlocks(rddId).isEmpty) {
- _rddBlocks.remove(rddId)
- }
- removed
- } else {
- None
- }
- case _ =>
- _nonRddBlocks.remove(blockId)
- }
- }
-
- /**
- * Return whether the given block is stored in this block manager in O(1) time.
- *
- * @note This is much faster than `this.blocks.contains`, which is O(blocks) time.
- */
- def containsBlock(blockId: BlockId): Boolean = {
- blockId match {
- case RDDBlockId(rddId, _) =>
- _rddBlocks.get(rddId).exists(_.contains(blockId))
- case _ =>
- _nonRddBlocks.contains(blockId)
- }
- }
-
/**
* Return the given block stored in this block manager in O(1) time.
*
@@ -155,37 +105,12 @@ class StorageStatus(
}
}
- /**
- * Return the number of blocks stored in this block manager in O(RDDs) time.
- *
- * @note This is much faster than `this.blocks.size`, which is O(blocks) time.
- */
- def numBlocks: Int = _nonRddBlocks.size + numRddBlocks
-
- /**
- * Return the number of RDD blocks stored in this block manager in O(RDDs) time.
- *
- * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time.
- */
- def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum
-
- /**
- * Return the number of blocks that belong to the given RDD in O(1) time.
- *
- * @note This is much faster than `this.rddBlocksById(rddId).size`, which is
- * O(blocks in this RDD) time.
- */
- def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0)
-
/** Return the max memory can be used by this block manager. */
def maxMem: Long = maxMemory
/** Return the memory remaining in this block manager. */
def memRemaining: Long = maxMem - memUsed
- /** Return the memory used by caching RDDs */
- def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L)
-
/** Return the memory used by this block manager. */
def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L)
@@ -220,15 +145,9 @@ class StorageStatus(
/** Return the disk space used by this block manager. */
def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum
- /** Return the memory used by the given RDD in this block manager in O(1) time. */
- def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L)
-
/** Return the disk space used by the given RDD in this block manager in O(1) time. */
def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L)
- /** Return the storage level, if any, used by the given RDD in this block manager. */
- def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level)
-
/**
* Update the relevant storage info, taking into account any existing status for this block.
*/
@@ -295,40 +214,4 @@ private[spark] object StorageUtils extends Logging {
cleaner.clean()
}
}
-
- /**
- * Update the given list of RDDInfo with the given list of storage statuses.
- * This method overwrites the old values stored in the RDDInfo's.
- */
- def updateRddInfo(rddInfos: Seq[RDDInfo], statuses: Seq[StorageStatus]): Unit = {
- rddInfos.foreach { rddInfo =>
- val rddId = rddInfo.id
- // Assume all blocks belonging to the same RDD have the same storage level
- val storageLevel = statuses
- .flatMap(_.rddStorageLevel(rddId)).headOption.getOrElse(StorageLevel.NONE)
- val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum
- val memSize = statuses.map(_.memUsedByRdd(rddId)).sum
- val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum
-
- rddInfo.storageLevel = storageLevel
- rddInfo.numCachedPartitions = numCachedPartitions
- rddInfo.memSize = memSize
- rddInfo.diskSize = diskSize
- }
- }
-
- /**
- * Return a mapping from block ID to its locations for each block that belongs to the given RDD.
- */
- def getRddBlockLocations(rddId: Int, statuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = {
- val blockLocations = new mutable.HashMap[BlockId, mutable.ListBuffer[String]]
- statuses.foreach { status =>
- status.rddBlocksById(rddId).foreach { case (bid, _) =>
- val location = status.blockManagerId.hostPort
- blockLocations.getOrElseUpdate(bid, mutable.ListBuffer.empty) += location
- }
- }
- blockLocations
- }
-
}
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 0adeb4058b6e4..52a955111231a 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -263,7 +263,7 @@ private[spark] object JettyUtils extends Logging {
filters.foreach {
case filter : String =>
if (!filter.isEmpty) {
- logInfo("Adding filter: " + filter)
+ logInfo(s"Adding filter $filter to ${handlers.map(_.getContextPath).mkString(", ")}.")
val holder : FilterHolder = new FilterHolder()
holder.setClassName(filter)
// Get any parameters for each filter
@@ -343,12 +343,14 @@ private[spark] object JettyUtils extends Logging {
-1,
connectionFactories: _*)
connector.setPort(port)
- connector.start()
+ connector.setHost(hostName)
+ connector.setReuseAddress(!Utils.isWindows)
// Currently we only use "SelectChannelConnector"
// Limit the max acceptor number to 8 so that we don't waste a lot of threads
connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8))
- connector.setHost(hostName)
+
+ connector.start()
// The number of selectors always equals to the number of acceptors
minThreads += connector.getAcceptors * 2
@@ -405,7 +407,7 @@ private[spark] object JettyUtils extends Logging {
}
pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))
- ServerInfo(server, httpPort, securePort, collection)
+ ServerInfo(server, httpPort, securePort, conf, collection)
} catch {
case e: Exception =>
server.stop()
@@ -505,10 +507,12 @@ private[spark] case class ServerInfo(
server: Server,
boundPort: Int,
securePort: Option[Int],
+ conf: SparkConf,
private val rootHandler: ContextHandlerCollection) {
- def addHandler(handler: ContextHandler): Unit = {
+ def addHandler(handler: ServletContextHandler): Unit = {
handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME))
+ JettyUtils.addFilters(Seq(handler), conf)
rootHandler.addHandler(handler)
if (!handler.isStarted()) {
handler.start()
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index b44ac0ea1febc..d315ef66e0dc0 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -65,7 +65,7 @@ private[spark] class SparkUI private (
attachTab(new StorageTab(this, store))
attachTab(new EnvironmentTab(this, store))
attachTab(new ExecutorsTab(this))
- attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static"))
+ addStaticHandler(SparkUI.STATIC_RESOURCE_DIR)
attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath))
attachHandler(ApiRootResource.getServletHandler(this))
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index ba798df13c95d..5d015b0531ef6 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ui
import java.net.URLDecoder
import java.text.SimpleDateFormat
import java.util.{Date, Locale, TimeZone}
+import javax.servlet.http.HttpServletRequest
import scala.util.control.NonFatal
import scala.xml._
@@ -148,60 +149,71 @@ private[spark] object UIUtils extends Logging {
}
// Yarn has to go through a proxy so the base uri is provided and has to be on all links
- def uiRoot: String = {
+ def uiRoot(request: HttpServletRequest): String = {
+ // Knox uses X-Forwarded-Context to notify the application the base path
+ val knoxBasePath = Option(request.getHeader("X-Forwarded-Context"))
// SPARK-11484 - Use the proxyBase set by the AM, if not found then use env.
sys.props.get("spark.ui.proxyBase")
.orElse(sys.env.get("APPLICATION_WEB_PROXY_BASE"))
+ .orElse(knoxBasePath)
.getOrElse("")
}
- def prependBaseUri(basePath: String = "", resource: String = ""): String = {
- uiRoot + basePath + resource
+ def prependBaseUri(
+ request: HttpServletRequest,
+ basePath: String = "",
+ resource: String = ""): String = {
+ uiRoot(request) + basePath + resource
}
- def commonHeaderNodes: Seq[Node] = {
+ def commonHeaderNodes(request: HttpServletRequest): Seq[Node] = {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
}
- def vizHeaderNodes: Seq[Node] = {
-
-
-
-
-
+ def vizHeaderNodes(request: HttpServletRequest): Seq[Node] = {
+
+
+
+
+
}
- def dataTablesHeaderNodes: Seq[Node] = {
+ def dataTablesHeaderNodes(request: HttpServletRequest): Seq[Node] = {
+
+ href={prependBaseUri(request, "/static/dataTables.bootstrap.css")} type="text/css"/>
-
-
-
-
-
-
-
+ href={prependBaseUri(request, "/static/jsonFormatter.min.css")} type="text/css"/>
+
+
+
+
+
+
}
/** Returns a spark page with correctly formatted headers */
def headerSparkPage(
+ request: HttpServletRequest,
title: String,
content: => Seq[Node],
activeTab: SparkUITab,
@@ -214,24 +226,26 @@ private[spark] object UIUtils extends Logging {
val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..."
val header = activeTab.headerTabs.map { tab =>
-
-
+
+
{org.apache.spark.SPARK_VERSION}
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 8b75f5d8fe1a8..2e43f17e6a8e3 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -60,23 +60,25 @@ private[spark] abstract class WebUI(
def getHandlers: Seq[ServletContextHandler] = handlers
def getSecurityManager: SecurityManager = securityManager
- /** Attach a tab to this UI, along with all of its attached pages. */
- def attachTab(tab: WebUITab) {
+ /** Attaches a tab to this UI, along with all of its attached pages. */
+ def attachTab(tab: WebUITab): Unit = {
tab.pages.foreach(attachPage)
tabs += tab
}
- def detachTab(tab: WebUITab) {
+ /** Detaches a tab from this UI, along with all of its attached pages. */
+ def detachTab(tab: WebUITab): Unit = {
tab.pages.foreach(detachPage)
tabs -= tab
}
- def detachPage(page: WebUIPage) {
+ /** Detaches a page from this UI, along with all of its attached handlers. */
+ def detachPage(page: WebUIPage): Unit = {
pageToHandlers.remove(page).foreach(_.foreach(detachHandler))
}
- /** Attach a page to this UI. */
- def attachPage(page: WebUIPage) {
+ /** Attaches a page to this UI. */
+ def attachPage(page: WebUIPage): Unit = {
val pagePath = "/" + page.prefix
val renderHandler = createServletHandler(pagePath,
(request: HttpServletRequest) => page.render(request), securityManager, conf, basePath)
@@ -88,41 +90,41 @@ private[spark] abstract class WebUI(
handlers += renderHandler
}
- /** Attach a handler to this UI. */
- def attachHandler(handler: ServletContextHandler) {
+ /** Attaches a handler to this UI. */
+ def attachHandler(handler: ServletContextHandler): Unit = {
handlers += handler
serverInfo.foreach(_.addHandler(handler))
}
- /** Detach a handler from this UI. */
- def detachHandler(handler: ServletContextHandler) {
+ /** Detaches a handler from this UI. */
+ def detachHandler(handler: ServletContextHandler): Unit = {
handlers -= handler
serverInfo.foreach(_.removeHandler(handler))
}
/**
- * Add a handler for static content.
+ * Detaches the content handler at `path` URI.
*
- * @param resourceBase Root of where to find resources to serve.
- * @param path Path in UI where to mount the resources.
+ * @param path Path in UI to unmount.
*/
- def addStaticHandler(resourceBase: String, path: String): Unit = {
- attachHandler(JettyUtils.createStaticHandler(resourceBase, path))
+ def detachHandler(path: String): Unit = {
+ handlers.find(_.getContextPath() == path).foreach(detachHandler)
}
/**
- * Remove a static content handler.
+ * Adds a handler for static content.
*
- * @param path Path in UI to unmount.
+ * @param resourceBase Root of where to find resources to serve.
+ * @param path Path in UI where to mount the resources.
*/
- def removeStaticHandler(path: String): Unit = {
- handlers.find(_.getContextPath() == path).foreach(detachHandler)
+ def addStaticHandler(resourceBase: String, path: String = "/static"): Unit = {
+ attachHandler(JettyUtils.createStaticHandler(resourceBase, path))
}
- /** Initialize all components of the server. */
+ /** A hook to initialize components of the UI */
def initialize(): Unit
- /** Bind to the HTTP server behind this web interface. */
+ /** Binds to the HTTP server behind this web interface. */
def bind(): Unit = {
assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!")
try {
@@ -136,17 +138,17 @@ private[spark] abstract class WebUI(
}
}
- /** Return the url of web interface. Only valid after bind(). */
+ /** @return The url of web interface. Only valid after [[bind]]. */
def webUrl: String = s"http://$publicHostName:$boundPort"
- /** Return the actual port to which this server is bound. Only valid after bind(). */
+ /** @return The actual port to which this server is bound. Only valid after [[bind]]. */
def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1)
- /** Stop the server behind this web interface. Only valid after bind(). */
+ /** Stops the server behind this web interface. Only valid after [[bind]]. */
def stop(): Unit = {
assert(serverInfo.isDefined,
s"Attempted to stop $className before binding to a server!")
- serverInfo.get.stop()
+ serverInfo.foreach(_.stop())
}
}
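
A minimal sketch of the reshuffled static-handler API above, assuming some already-bound `ui: WebUI` instance; the resource directory and the "/extra" path are illustrative, not taken from the patch:

    // Mounts the resources at the default "/static" context path.
    ui.addStaticHandler("org/apache/spark/ui/static")
    // An explicit mount path can still be passed for additional resources.
    ui.addStaticHandler("my/extra/resources", "/extra")
    // detachHandler(path) now unmounts whatever handler sits at that context path.
    ui.detachHandler("/extra")
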
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
index 902eb92b854f2..3d465a34e44aa 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
@@ -94,7 +94,7 @@ private[ui] class EnvironmentPage(
- UIUtils.headerSparkPage("Environment", content, parent)
+ UIUtils.headerSparkPage(request, "Environment", content, parent)
}
private def propertyHeader = Seq("Name", "Value")
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
index f4686ea3cf91f..f9713fb5b4a3c 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -17,7 +17,6 @@
package org.apache.spark.ui.exec
-import java.util.Locale
import javax.servlet.http.HttpServletRequest
import scala.xml.{Node, Text}
@@ -41,17 +40,7 @@ private[ui] class ExecutorThreadDumpPage(
val maybeThreadDump = sc.get.getExecutorThreadDump(executorId)
val content = maybeThreadDump.map { threadDump =>
- val dumpRows = threadDump.sortWith {
- case (threadTrace1, threadTrace2) =>
- val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0
- val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0
- if (v1 == v2) {
- threadTrace1.threadName.toLowerCase(Locale.ROOT) <
- threadTrace2.threadName.toLowerCase(Locale.ROOT)
- } else {
- v1 > v2
- }
- }.map { thread =>
+ val dumpRows = threadDump.map { thread =>
val threadId = thread.threadId
val blockedBy = thread.blockedByThreadId match {
case Some(_) =>
@@ -71,7 +60,7 @@ private[ui] class ExecutorThreadDumpPage(
{thread.threadName}
{thread.threadState}
{blockedBy}{heldLocks}
-
{thread.stackTrace}
+
{thread.stackTrace.html}
}
@@ -108,6 +97,6 @@ private[ui] class ExecutorThreadDumpPage(
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index f4a736d6d439a..bf618b4afbce0 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext}
+import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.AccumulableInfo
private[spark] case class AccumulatorMetadata(
@@ -199,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
}
override def toString: String = {
+ // getClass.getSimpleName can throw a "Malformed class name" error,
+ // so call the safer `Utils.getSimpleName` instead
if (metadata == null) {
- "Un-registered Accumulator: " + getClass.getSimpleName
+ "Un-registered Accumulator: " + Utils.getSimpleName(getClass)
} else {
- getClass.getSimpleName + s"(id: $id, name: $name, value: $value)"
+ Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)"
}
}
}
@@ -211,7 +214,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
/**
* An internal class used to track accumulators by Spark itself.
*/
-private[spark] object AccumulatorContext {
+private[spark] object AccumulatorContext extends Logging {
/**
* This global map holds the original accumulator objects that are created on the driver.
@@ -258,13 +261,16 @@ private[spark] object AccumulatorContext {
* Returns the [[AccumulatorV2]] registered with the given ID, if any.
*/
def get(id: Long): Option[AccumulatorV2[_, _]] = {
- Option(originals.get(id)).map { ref =>
- // Since we are storing weak references, we must check whether the underlying data is valid.
+ val ref = originals.get(id)
+ if (ref eq null) {
+ None
+ } else {
+ // Since we are storing weak references, warn when the underlying data is not valid.
val acc = ref.get
if (acc eq null) {
- throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id")
+ logWarning(s"Attempted to access garbage collected accumulator $id")
}
- acc
+ Option(acc)
}
}
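
The behavioural change above, restated as a small caller-side sketch; the `id` value and the println calls are illustrative only:

    AccumulatorContext.get(id) match {
      case Some(acc) =>
        // The accumulator is still registered and reachable on the driver.
        println(s"accumulator $id is still registered: ${acc.name}")
      case None =>
        // Unregistered or already garbage collected. Before this patch the garbage-collected
        // case raised an IllegalStateException; now get() logs a warning and returns None.
        println(s"accumulator $id is gone")
    }
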
@@ -290,7 +296,8 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
private var _count = 0L
/**
- * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * Returns false if this accumulator has had any values added to it or the sum is non-zero.
+ *
* @since 2.0.0
*/
override def isZero: Boolean = _sum == 0L && _count == 0
@@ -368,6 +375,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
private var _sum = 0.0
private var _count = 0L
+ /**
+ * Returns false if this accumulator has had any values added to it or the sum is non-zero.
+ */
override def isZero: Boolean = _sum == 0.0 && _count == 0
override def copy(): DoubleAccumulator = {
@@ -441,6 +451,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]())
+ /**
+ * Returns false if this accumulator instance has any values in it.
+ */
override def isZero: Boolean = _list.isEmpty
override def copyAndReset(): CollectionAccumulator[T] = new CollectionAccumulator
@@ -479,7 +492,9 @@ class LegacyAccumulatorWrapper[R, T](
param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
private[spark] var _value = initialValue // Current value on driver
- override def isZero: Boolean = _value == param.zero(initialValue)
+ @transient private lazy val _zero = param.zero(initialValue)
+
+ override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef])
override def copy(): LegacyAccumulatorWrapper[R, T] = {
val acc = new LegacyAccumulatorWrapper(initialValue, param)
@@ -488,7 +503,7 @@ class LegacyAccumulatorWrapper[R, T](
}
override def reset(): Unit = {
- _value = param.zero(initialValue)
+ _value = _zero
}
override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala
index d73901686b705..4b6602b50aa1c 100644
--- a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala
@@ -33,24 +33,14 @@ private[spark] trait CommandLineUtils {
private[spark] var printStream: PrintStream = System.err
// scalastyle:off println
-
- private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str)
+ private[spark] def printMessage(str: String): Unit = printStream.println(str)
+ // scalastyle:on println
private[spark] def printErrorAndExit(str: String): Unit = {
- printStream.println("Error: " + str)
- printStream.println("Run with --help for usage help or --verbose for debug output")
+ printMessage("Error: " + str)
+ printMessage("Run with --help for usage help or --verbose for debug output")
exitFn(1)
}
- // scalastyle:on println
-
- private[spark] def parseSparkConfProperty(pair: String): (String, String) = {
- pair.split("=", 2).toSeq match {
- case Seq(k, v) => (k, v)
- case _ => printErrorAndExit(s"Spark config without '=': $pair")
- throw new SparkException(s"Spark config without '=': $pair")
- }
- }
-
def main(args: Array[String]): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
index 31d230d0fec8e..21acaa95c5645 100644
--- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
+++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
@@ -22,9 +22,7 @@ package org.apache.spark.util
* through all the elements.
*/
private[spark]
-// scalastyle:off
abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] {
-// scalastyle:on
private[this] var completed = false
def next(): A = sub.next()
diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala
index 3ea9139e11027..651ea4996f6cb 100644
--- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala
+++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala
@@ -37,7 +37,8 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging {
private val stopped = new AtomicBoolean(false)
- private val eventThread = new Thread(name) {
+ // Exposed for testing.
+ private[spark] val eventThread = new Thread(name) {
setDaemon(true)
override def run(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index ff83301d631c4..50c6461373dee 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -48,7 +48,7 @@ import org.apache.spark.storage._
* To ensure that we provide these guarantees, follow these rules when modifying these methods:
*
* - Never delete any JSON fields.
- * - Any new JSON fields should be optional; use `Utils.jsonOption` when reading these fields
+ * - Any new JSON fields should be optional; use `jsonOption` when reading these fields
* in `*FromJson` methods.
*/
private[spark] object JsonProtocol {
@@ -407,8 +407,10 @@ private[spark] object JsonProtocol {
("Exit Caused By App" -> exitCausedByApp) ~
("Loss Reason" -> reason.map(_.toString))
case taskKilled: TaskKilled =>
- ("Kill Reason" -> taskKilled.reason)
- case _ => Utils.emptyJson
+ val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList)
+ ("Kill Reason" -> taskKilled.reason) ~
+ ("Accumulator Updates" -> accumUpdates)
+ case _ => emptyJson
}
("Reason" -> reason) ~ json
}
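
A round-trip sketch of the extended TaskKilled serialization shown above; the kill reason is made up and the accumulator list is left empty for brevity:

    import org.apache.spark.TaskKilled
    import org.apache.spark.util.JsonProtocol

    val killed = TaskKilled("another attempt succeeded", accumUpdates = Seq.empty)
    val json = JsonProtocol.taskEndReasonToJson(killed)
    // The JSON now carries both "Kill Reason" and "Accumulator Updates".
    assert(JsonProtocol.taskEndReasonFromJson(json) == killed)
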
@@ -422,7 +424,7 @@ private[spark] object JsonProtocol {
def jobResultToJson(jobResult: JobResult): JValue = {
val result = Utils.getFormattedClassName(jobResult)
val json = jobResult match {
- case JobSucceeded => Utils.emptyJson
+ case JobSucceeded => emptyJson
case jobFailed: JobFailed =>
JObject("Exception" -> exceptionToJson(jobFailed.exception))
}
@@ -573,7 +575,7 @@ private[spark] object JsonProtocol {
def taskStartFromJson(json: JValue): SparkListenerTaskStart = {
val stageId = (json \ "Stage ID").extract[Int]
val stageAttemptId =
- Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
+ jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
val taskInfo = taskInfoFromJson(json \ "Task Info")
SparkListenerTaskStart(stageId, stageAttemptId, taskInfo)
}
@@ -586,7 +588,7 @@ private[spark] object JsonProtocol {
def taskEndFromJson(json: JValue): SparkListenerTaskEnd = {
val stageId = (json \ "Stage ID").extract[Int]
val stageAttemptId =
- Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
+ jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
val taskType = (json \ "Task Type").extract[String]
val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason")
val taskInfo = taskInfoFromJson(json \ "Task Info")
@@ -597,11 +599,11 @@ private[spark] object JsonProtocol {
def jobStartFromJson(json: JValue): SparkListenerJobStart = {
val jobId = (json \ "Job ID").extract[Int]
val submissionTime =
- Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L)
+ jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L)
val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int])
val properties = propertiesFromJson(json \ "Properties")
// The "Stage Infos" field was added in Spark 1.2.0
- val stageInfos = Utils.jsonOption(json \ "Stage Infos")
+ val stageInfos = jsonOption(json \ "Stage Infos")
.map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse {
stageIds.map { id =>
new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")
@@ -613,7 +615,7 @@ private[spark] object JsonProtocol {
def jobEndFromJson(json: JValue): SparkListenerJobEnd = {
val jobId = (json \ "Job ID").extract[Int]
val completionTime =
- Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L)
+ jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L)
val jobResult = jobResultFromJson(json \ "Job Result")
SparkListenerJobEnd(jobId, completionTime, jobResult)
}
@@ -630,15 +632,15 @@ private[spark] object JsonProtocol {
def blockManagerAddedFromJson(json: JValue): SparkListenerBlockManagerAdded = {
val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID")
val maxMem = (json \ "Maximum Memory").extract[Long]
- val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L)
- val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long])
- val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long])
+ val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L)
+ val maxOnHeapMem = jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long])
+ val maxOffHeapMem = jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long])
SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem)
}
def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = {
val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID")
- val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L)
+ val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L)
SparkListenerBlockManagerRemoved(time, blockManagerId)
}
@@ -648,11 +650,11 @@ private[spark] object JsonProtocol {
def applicationStartFromJson(json: JValue): SparkListenerApplicationStart = {
val appName = (json \ "App Name").extract[String]
- val appId = Utils.jsonOption(json \ "App ID").map(_.extract[String])
+ val appId = jsonOption(json \ "App ID").map(_.extract[String])
val time = (json \ "Timestamp").extract[Long]
val sparkUser = (json \ "User").extract[String]
- val appAttemptId = Utils.jsonOption(json \ "App Attempt ID").map(_.extract[String])
- val driverLogs = Utils.jsonOption(json \ "Driver Logs").map(mapFromJson)
+ val appAttemptId = jsonOption(json \ "App Attempt ID").map(_.extract[String])
+ val driverLogs = jsonOption(json \ "Driver Logs").map(mapFromJson)
SparkListenerApplicationStart(appName, appId, time, sparkUser, appAttemptId, driverLogs)
}
@@ -703,19 +705,19 @@ private[spark] object JsonProtocol {
def stageInfoFromJson(json: JValue): StageInfo = {
val stageId = (json \ "Stage ID").extract[Int]
- val attemptId = Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
+ val attemptId = jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0)
val stageName = (json \ "Stage Name").extract[String]
val numTasks = (json \ "Number of Tasks").extract[Int]
val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson)
- val parentIds = Utils.jsonOption(json \ "Parent IDs")
+ val parentIds = jsonOption(json \ "Parent IDs")
.map { l => l.extract[List[JValue]].map(_.extract[Int]) }
.getOrElse(Seq.empty)
- val details = Utils.jsonOption(json \ "Details").map(_.extract[String]).getOrElse("")
- val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long])
- val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long])
- val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String])
+ val details = jsonOption(json \ "Details").map(_.extract[String]).getOrElse("")
+ val submissionTime = jsonOption(json \ "Submission Time").map(_.extract[Long])
+ val completionTime = jsonOption(json \ "Completion Time").map(_.extract[Long])
+ val failureReason = jsonOption(json \ "Failure Reason").map(_.extract[String])
val accumulatedValues = {
- Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match {
+ jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match {
case Some(values) => values.map(accumulableInfoFromJson)
case None => Seq.empty[AccumulableInfo]
}
@@ -735,17 +737,17 @@ private[spark] object JsonProtocol {
def taskInfoFromJson(json: JValue): TaskInfo = {
val taskId = (json \ "Task ID").extract[Long]
val index = (json \ "Index").extract[Int]
- val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1)
+ val attempt = jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1)
val launchTime = (json \ "Launch Time").extract[Long]
val executorId = (json \ "Executor ID").extract[String].intern()
val host = (json \ "Host").extract[String].intern()
val taskLocality = TaskLocality.withName((json \ "Locality").extract[String])
- val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean])
+ val speculative = jsonOption(json \ "Speculative").exists(_.extract[Boolean])
val gettingResultTime = (json \ "Getting Result Time").extract[Long]
val finishTime = (json \ "Finish Time").extract[Long]
val failed = (json \ "Failed").extract[Boolean]
- val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean])
- val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match {
+ val killed = jsonOption(json \ "Killed").exists(_.extract[Boolean])
+ val accumulables = jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match {
case Some(values) => values.map(accumulableInfoFromJson)
case None => Seq.empty[AccumulableInfo]
}
@@ -762,13 +764,13 @@ private[spark] object JsonProtocol {
def accumulableInfoFromJson(json: JValue): AccumulableInfo = {
val id = (json \ "ID").extract[Long]
- val name = Utils.jsonOption(json \ "Name").map(_.extract[String])
- val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) }
- val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) }
- val internal = Utils.jsonOption(json \ "Internal").exists(_.extract[Boolean])
+ val name = jsonOption(json \ "Name").map(_.extract[String])
+ val update = jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) }
+ val value = jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) }
+ val internal = jsonOption(json \ "Internal").exists(_.extract[Boolean])
val countFailedValues =
- Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean])
- val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String])
+ jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean])
+ val metadata = jsonOption(json \ "Metadata").map(_.extract[String])
new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata)
}
@@ -821,49 +823,49 @@ private[spark] object JsonProtocol {
metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long])
// Shuffle read metrics
- Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson =>
+ jsonOption(json \ "Shuffle Read Metrics").foreach { readJson =>
val readMetrics = metrics.createTempShuffleReadMetrics()
readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int])
readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int])
readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long])
- Utils.jsonOption(readJson \ "Remote Bytes Read To Disk")
+ jsonOption(readJson \ "Remote Bytes Read To Disk")
.foreach { v => readMetrics.incRemoteBytesReadToDisk(v.extract[Long])}
readMetrics.incLocalBytesRead(
- Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L))
+ jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L))
readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long])
readMetrics.incRecordsRead(
- Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L))
+ jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L))
metrics.mergeShuffleReadMetrics()
}
// Shuffle write metrics
// TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes.
- Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson =>
+ jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson =>
val writeMetrics = metrics.shuffleWriteMetrics
writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long])
writeMetrics.incRecordsWritten(
- Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L))
+ jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L))
writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long])
}
// Output metrics
- Utils.jsonOption(json \ "Output Metrics").foreach { outJson =>
+ jsonOption(json \ "Output Metrics").foreach { outJson =>
val outputMetrics = metrics.outputMetrics
outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long])
outputMetrics.setRecordsWritten(
- Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L))
+ jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L))
}
// Input metrics
- Utils.jsonOption(json \ "Input Metrics").foreach { inJson =>
+ jsonOption(json \ "Input Metrics").foreach { inJson =>
val inputMetrics = metrics.inputMetrics
inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long])
inputMetrics.incRecordsRead(
- Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L))
+ jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L))
}
// Updated blocks
- Utils.jsonOption(json \ "Updated Blocks").foreach { blocksJson =>
+ jsonOption(json \ "Updated Blocks").foreach { blocksJson =>
metrics.setUpdatedBlockStatuses(blocksJson.extract[List[JValue]].map { blockJson =>
val id = BlockId((blockJson \ "Block ID").extract[String])
val status = blockStatusFromJson(blockJson \ "Status")
@@ -897,7 +899,7 @@ private[spark] object JsonProtocol {
val shuffleId = (json \ "Shuffle ID").extract[Int]
val mapId = (json \ "Map ID").extract[Int]
val reduceId = (json \ "Reduce ID").extract[Int]
- val message = Utils.jsonOption(json \ "Message").map(_.extract[String])
+ val message = jsonOption(json \ "Message").map(_.extract[String])
new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId,
message.getOrElse("Unknown reason"))
case `exceptionFailure` =>
@@ -905,9 +907,9 @@ private[spark] object JsonProtocol {
val description = (json \ "Description").extract[String]
val stackTrace = stackTraceFromJson(json \ "Stack Trace")
val fullStackTrace =
- Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull
+ jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull
// Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x
- val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates")
+ val accumUpdates = jsonOption(json \ "Accumulator Updates")
.map(_.extract[List[JValue]].map(accumulableInfoFromJson))
.getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators().map(acc => {
acc.toInfo(Some(acc.value), None)
@@ -915,21 +917,24 @@ private[spark] object JsonProtocol {
ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates)
case `taskResultLost` => TaskResultLost
case `taskKilled` =>
- val killReason = Utils.jsonOption(json \ "Kill Reason")
+ val killReason = jsonOption(json \ "Kill Reason")
.map(_.extract[String]).getOrElse("unknown reason")
- TaskKilled(killReason)
+ val accumUpdates = jsonOption(json \ "Accumulator Updates")
+ .map(_.extract[List[JValue]].map(accumulableInfoFromJson))
+ .getOrElse(Seq[AccumulableInfo]())
+ TaskKilled(killReason, accumUpdates)
case `taskCommitDenied` =>
// Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON
// de/serialization logic was not added until 1.5.1. To provide backward compatibility
// for reading those logs, we need to provide default values for all the fields.
- val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1)
- val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1)
- val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1)
+ val jobId = jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1)
+ val partitionId = jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1)
+ val attemptNo = jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1)
TaskCommitDenied(jobId, partitionId, attemptNo)
case `executorLostFailure` =>
- val exitCausedByApp = Utils.jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean])
- val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String])
- val reason = Utils.jsonOption(json \ "Loss Reason").map(_.extract[String])
+ val exitCausedByApp = jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean])
+ val executorId = jsonOption(json \ "Executor ID").map(_.extract[String])
+ val reason = jsonOption(json \ "Loss Reason").map(_.extract[String])
ExecutorLostFailure(
executorId.getOrElse("Unknown"),
exitCausedByApp.getOrElse(true),
@@ -968,11 +973,11 @@ private[spark] object JsonProtocol {
def rddInfoFromJson(json: JValue): RDDInfo = {
val rddId = (json \ "RDD ID").extract[Int]
val name = (json \ "Name").extract[String]
- val scope = Utils.jsonOption(json \ "Scope")
+ val scope = jsonOption(json \ "Scope")
.map(_.extract[String])
.map(RDDOperationScope.fromJson)
- val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("")
- val parentIds = Utils.jsonOption(json \ "Parent IDs")
+ val callsite = jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("")
+ val parentIds = jsonOption(json \ "Parent IDs")
.map { l => l.extract[List[JValue]].map(_.extract[Int]) }
.getOrElse(Seq.empty)
val storageLevel = storageLevelFromJson(json \ "Storage Level")
@@ -1029,7 +1034,7 @@ private[spark] object JsonProtocol {
}
def propertiesFromJson(json: JValue): Properties = {
- Utils.jsonOption(json).map { value =>
+ jsonOption(json).map { value =>
val properties = new Properties
mapFromJson(json).foreach { case (k, v) => properties.setProperty(k, v) }
properties
@@ -1058,4 +1063,14 @@ private[spark] object JsonProtocol {
e
}
+ /** Return an option that translates JNothing to None */
+ private def jsonOption(json: JValue): Option[JValue] = {
+ json match {
+ case JNothing => None
+ case value: JValue => Some(value)
+ }
+ }
+
+ private def emptyJson: JObject = JObject(List[JField]())
+
}
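
The pattern the `*FromJson` methods use with the now-private helper, restated as a standalone sketch (the field name is just an example, and an implicit json4s Formats is assumed to be in scope as it is inside JsonProtocol):

    // jsonOption turns a missing field (JNothing) into None, so optional fields added in later
    // Spark versions can be read from older event logs with a sensible default.
    val killReason = jsonOption(json \ "Kill Reason")
      .map(_.extract[String])
      .getOrElse("unknown reason")
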
diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
index 76a56298aaebc..d4474a90b26f1 100644
--- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -60,6 +60,15 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
}
}
+ /**
+ * This can be overridden by subclasses if there is any extra cleanup to do when removing a
+ * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus.
+ */
+ def removeListenerOnError(listener: L): Unit = {
+ removeListener(listener)
+ }
+
+
/**
* Post the event to all registered listeners. The `postToAll` caller should guarantee calling
* `postToAll` in the same thread for all events.
@@ -80,8 +89,17 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
}
try {
doPostEvent(listener, event)
+ if (Thread.interrupted()) {
+ // We want to throw the InterruptedException right away so we can associate the interrupt
+ // with this listener, as opposed to waiting for a queue.take() etc. to detect it.
+ throw new InterruptedException()
+ }
} catch {
- case NonFatal(e) =>
+ case ie: InterruptedException =>
+ logError(s"Interrupted while posting to ${Utils.getFormattedClassName(listener)}. " +
+ s"Removing that listener.", ie)
+ removeListenerOnError(listener)
+ case NonFatal(e) if !isIgnorableException(e) =>
logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
} finally {
if (maybeTimerContext != null) {
@@ -97,6 +115,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
*/
protected def doPostEvent(listener: L, event: E): Unit
+ /** Allows bus implementations to prevent error logging for certain exceptions. */
+ protected def isIgnorableException(e: Throwable): Boolean = false
+
private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = {
val c = implicitly[ClassTag[T]].runtimeClass
listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq
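
A hypothetical bus implementation showing where the two new hooks plug in; the class, its simplified dispatch, and the LiveListenerBus reference are illustrative and not part of the patch:

    import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent, SparkListenerInterface}
    import org.apache.spark.util.ListenerBus

    private class SingleListenerQueue(owner: LiveListenerBus)
      extends ListenerBus[SparkListenerInterface, SparkListenerEvent] {

      // Called instead of removeListener when posting failed or was interrupted, so the owning
      // bus can tear down any per-listener state (e.g. a dedicated queue), not just this bus.
      override def removeListenerOnError(listener: SparkListenerInterface): Unit = {
        owner.removeListener(listener)
      }

      // Suppress error logging for exceptions this bus considers benign (example predicate only).
      override protected def isIgnorableException(e: Throwable): Boolean =
        e.isInstanceOf[java.nio.channels.ClosedByInterruptException]

      // Simplified dispatch; real queues route each event to the matching listener callback.
      override protected def doPostEvent(
          listener: SparkListenerInterface,
          event: SparkListenerEvent): Unit = {
        listener.onOtherEvent(event)
      }
    }
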
diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
index 4001fac3c3d5a..b702838fa257f 100644
--- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
+++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
@@ -143,7 +143,7 @@ private[spark] object ShutdownHookManager extends Logging {
}
/**
- * Adds a shutdown hook with the given priority. Hooks with lower priority values run
+ * Adds a shutdown hook with the given priority. Hooks with higher priority values run
* first.
*
* @param hook The code to run during shutdown.
diff --git a/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala
new file mode 100644
index 0000000000000..1aa2009fa9b5b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util
+
+/**
+ * SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we catch
+ * fatal throwable in {@link scala.concurrent.Future}'s body, and re-throw
+ * SparkFatalException, which wraps the fatal throwable inside.
+ * Note that SparkFatalException should only be thrown from a {@link scala.concurrent.Future},
+ * which is run by using ThreadUtils.awaitResult. ThreadUtils.awaitResult will catch
+ * it and re-throw the original exception/error.
+ */
+private[spark] final class SparkFatalException(val throwable: Throwable) extends Exception
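
A sketch of how the wrapper is meant to be used together with ThreadUtils.awaitResult; `doWork()` is a hypothetical body, and the same-thread execution context is used only for brevity:

    import scala.concurrent.Future
    import scala.concurrent.duration._
    import scala.util.control.NonFatal
    import org.apache.spark.util.{SparkFatalException, ThreadUtils}

    val future = Future {
      try {
        doWork() // hypothetical work that might throw an OutOfMemoryError
      } catch {
        // Without the wrapper, scala/bug#9554 lets the Promise swallow fatal errors.
        case t: Throwable if !NonFatal(t) => throw new SparkFatalException(t)
      }
    }(ThreadUtils.sameThread)

    // awaitResult unwraps SparkFatalException and rethrows the original fatal error.
    ThreadUtils.awaitResult(future, 10.seconds)
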
diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
index e0f5af5250e7f..1b34fbde38cd6 100644
--- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
+++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
@@ -39,10 +39,15 @@ private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException:
// We may have been called from a shutdown hook. If so, we must not call System.exit().
// (If we do, we will deadlock.)
if (!ShutdownHookManager.inShutdown()) {
- if (exception.isInstanceOf[OutOfMemoryError]) {
- System.exit(SparkExitCode.OOM)
- } else if (exitOnUncaughtException) {
- System.exit(SparkExitCode.UNCAUGHT_EXCEPTION)
+ exception match {
+ case _: OutOfMemoryError =>
+ System.exit(SparkExitCode.OOM)
+ case e: SparkFatalException if e.throwable.isInstanceOf[OutOfMemoryError] =>
+ // SPARK-24294: This is defensive code, in case that SparkFatalException is
+ // misused and uncaught.
+ System.exit(SparkExitCode.OOM)
+ case _ if exitOnUncaughtException =>
+ System.exit(SparkExitCode.UNCAUGHT_EXCEPTION)
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 81aaf79db0c13..0f08a2b0ad895 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -19,13 +19,12 @@ package org.apache.spark.util
import java.util.concurrent._
+import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
-import scala.concurrent.duration.Duration
+import scala.concurrent.duration.{Duration, FiniteDuration}
import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
import scala.util.control.NonFatal
-import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
-
import org.apache.spark.SparkException
private[spark] object ThreadUtils {
@@ -103,6 +102,22 @@ private[spark] object ThreadUtils {
executor
}
+ /**
+ * Wrapper over ScheduledThreadPoolExecutor.
+ */
+ def newDaemonThreadPoolScheduledExecutor(threadNamePrefix: String, numThreads: Int)
+ : ScheduledExecutorService = {
+ val threadFactory = new ThreadFactoryBuilder()
+ .setDaemon(true)
+ .setNameFormat(s"$threadNamePrefix-%d")
+ .build()
+ val executor = new ScheduledThreadPoolExecutor(numThreads, threadFactory)
+ // By default, a cancelled task is not automatically removed from the work queue until its delay
+ // elapses. We have to enable it manually.
+ executor.setRemoveOnCancelPolicy(true)
+ executor
+ }
+
/**
* Run a piece of code in a new thread and return the result. Exception in the new thread is
* thrown in the caller thread with an adjusted stack trace that removes references to this
@@ -200,6 +215,8 @@ private[spark] object ThreadUtils {
val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
awaitable.result(atMost)(awaitPermission)
} catch {
+ case e: SparkFatalException =>
+ throw e.throwable
// TimeoutException is thrown in the current thread, so there is no need to wrap the exception.
case NonFatal(t) if !t.isInstanceOf[TimeoutException] =>
throw new SparkException("Exception thrown in awaitResult: ", t)
@@ -227,4 +244,14 @@ private[spark] object ThreadUtils {
}
}
// scalastyle:on awaitready
+
+ def shutdown(
+ executor: ExecutorService,
+ gracePeriod: Duration = FiniteDuration(30, TimeUnit.SECONDS)): Unit = {
+ executor.shutdown()
+ executor.awaitTermination(gracePeriod.toMillis, TimeUnit.MILLISECONDS)
+ if (!executor.isShutdown) {
+ executor.shutdownNow()
+ }
+ }
}
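
A combined usage sketch for the two additions above; the pool name, the task body, and the grace period are made-up values:

    import java.util.concurrent.TimeUnit
    import scala.concurrent.duration._
    import org.apache.spark.util.ThreadUtils

    val scheduler = ThreadUtils.newDaemonThreadPoolScheduledExecutor("example-scheduler", 1)
    val heartbeat = scheduler.scheduleWithFixedDelay(new Runnable {
      override def run(): Unit = println("heartbeat")
    }, 0L, 10L, TimeUnit.SECONDS)

    // Because setRemoveOnCancelPolicy(true) is set by the factory method, cancelling removes
    // the task from the work queue immediately instead of waiting for its delay to elapse.
    heartbeat.cancel(false)

    // Waits up to the grace period, then calls shutdownNow() if the executor still reports
    // that it has not shut down.
    ThreadUtils.shutdown(scheduler, gracePeriod = 5.seconds)
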
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 5853302973140..c139db46b63a3 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,6 +18,8 @@
package org.apache.spark.util
import java.io._
+import java.lang.{Byte => JByte}
+import java.lang.InternalError
import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
import java.lang.reflect.InvocationTargetException
import java.math.{MathContext, RoundingMode}
@@ -25,12 +27,13 @@ import java.net._
import java.nio.ByteBuffer
import java.nio.channels.{Channels, FileChannel}
import java.nio.charset.StandardCharsets
-import java.nio.file.{Files, Paths}
+import java.nio.file.Files
+import java.security.SecureRandom
import java.util.{Locale, Properties, Random, UUID}
import java.util.concurrent._
+import java.util.concurrent.TimeUnit.NANOSECONDS
import java.util.concurrent.atomic.AtomicBoolean
import java.util.zip.GZIPInputStream
-import javax.net.ssl.HttpsURLConnection
import scala.annotation.tailrec
import scala.collection.JavaConverters._
@@ -44,6 +47,7 @@ import scala.util.matching.Regex
import _root_.io.netty.channel.unix.Errors.NativeIoException
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import com.google.common.hash.HashCodes
import com.google.common.io.{ByteStreams, Files => GFiles}
import com.google.common.net.InetAddresses
import org.apache.commons.lang3.SystemUtils
@@ -51,9 +55,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.log4j.PropertyConfigurator
import org.eclipse.jetty.util.MultiException
-import org.json4s._
import org.slf4j.Logger
import org.apache.spark._
@@ -63,6 +65,7 @@ import org.apache.spark.internal.config._
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
+import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace}
/** CallSite represents a place in user code. It can have a short and a long form. */
private[spark] case class CallSite(shortForm: String, longForm: String)
@@ -432,7 +435,7 @@ private[spark] object Utils extends Logging {
new URI("file:///" + rawFileName).getPath.substring(1)
}
- /**
+ /**
* Download a file or directory to target directory. Supports fetching the file in a variety of
* ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based
* on the URL parameter. Fetching directories is only supported from Hadoop-compatible
@@ -505,6 +508,14 @@ private[spark] object Utils extends Logging {
targetFile
}
+ /** Records the duration of running `body`. */
+ def timeTakenMs[T](body: => T): (T, Long) = {
+ val startTime = System.nanoTime()
+ val result = body
+ val endTime = System.nanoTime()
+ (result, math.max(NANOSECONDS.toMillis(endTime - startTime), 0))
+ }
+
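+  // A small usage sketch for the new timing helper; `fs` and `logDir` stand in for whatever
+  // the caller already has, and logInfo assumes a Logging mixin:
+  //
+  //   val (statuses, listTimeMs) = Utils.timeTakenMs {
+  //     fs.listStatus(logDir) // any block works; its result and elapsed millis are returned
+  //   }
+  //   logInfo(s"Listed ${statuses.length} entries in $listTimeMs ms")
+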
/**
* Download `in` to `tempFile`, then move it to `destFile`.
*
@@ -672,7 +683,6 @@ private[spark] object Utils extends Logging {
logDebug("fetchFile not using security")
uc = new URL(url).openConnection()
}
- Utils.setupSecureURLConnection(uc, securityMgr)
val timeoutMs =
conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000
@@ -810,15 +820,15 @@ private[spark] object Utils extends Logging {
conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator)
} else if (conf.getenv("SPARK_LOCAL_DIRS") != null) {
conf.getenv("SPARK_LOCAL_DIRS").split(",")
- } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) {
+ } else if (conf.getenv("MESOS_SANDBOX") != null && !shuffleServiceEnabled) {
// Mesos already creates a directory per Mesos task. Spark should use that directory
// instead so all temporary files are automatically cleaned up when the Mesos task ends.
// Note that we don't want this if the shuffle service is enabled because we want to
// continue to serve shuffle files after the executors that wrote them have already exited.
- Array(conf.getenv("MESOS_DIRECTORY"))
+ Array(conf.getenv("MESOS_SANDBOX"))
} else {
- if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) {
- logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " +
+ if (conf.getenv("MESOS_SANDBOX") != null && shuffleServiceEnabled) {
+ logInfo("MESOS_SANDBOX available but not using provided Mesos sandbox because " +
"spark.shuffle.service.enabled is enabled.")
}
// In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user
@@ -1017,70 +1027,18 @@ private[spark] object Utils extends Logging {
" " + (System.currentTimeMillis - startTimeMs) + " ms"
}
- private def listFilesSafely(file: File): Seq[File] = {
- if (file.exists()) {
- val files = file.listFiles()
- if (files == null) {
- throw new IOException("Failed to list files for dir: " + file)
- }
- files
- } else {
- List()
- }
- }
-
- /**
- * Lists files recursively.
- */
- def recursiveList(f: File): Array[File] = {
- require(f.isDirectory)
- val current = f.listFiles
- current ++ current.filter(_.isDirectory).flatMap(recursiveList)
- }
-
/**
* Delete a file or directory and its contents recursively.
* Don't follow directories if they are symlinks.
* Throws an exception if deletion is unsuccessful.
*/
- def deleteRecursively(file: File) {
+ def deleteRecursively(file: File): Unit = {
if (file != null) {
- try {
- if (file.isDirectory && !isSymlink(file)) {
- var savedIOException: IOException = null
- for (child <- listFilesSafely(file)) {
- try {
- deleteRecursively(child)
- } catch {
- // In case of multiple exceptions, only last one will be thrown
- case ioe: IOException => savedIOException = ioe
- }
- }
- if (savedIOException != null) {
- throw savedIOException
- }
- ShutdownHookManager.removeShutdownDeleteDir(file)
- }
- } finally {
- if (file.delete()) {
- logTrace(s"${file.getAbsolutePath} has been deleted")
- } else {
- // Delete can also fail if the file simply did not exist
- if (file.exists()) {
- throw new IOException("Failed to delete: " + file.getAbsolutePath)
- }
- }
- }
+ JavaUtils.deleteRecursively(file)
+ ShutdownHookManager.removeShutdownDeleteDir(file)
}
}
- /**
- * Check to see if file is a symbolic link.
- */
- def isSymlink(file: File): Boolean = {
- return Files.isSymbolicLink(Paths.get(file.toURI))
- }
-
/**
* Determines if a directory contains any files newer than cutoff seconds.
*
@@ -1828,7 +1786,7 @@ private[spark] object Utils extends Logging {
* [[scala.collection.Iterator#size]] because it uses a for loop, which is slightly slower
* in the current version of Scala.
*/
- def getIteratorSize[T](iterator: Iterator[T]): Long = {
+ def getIteratorSize(iterator: Iterator[_]): Long = {
var count = 0L
while (iterator.hasNext) {
count += 1L
@@ -1872,20 +1830,9 @@ private[spark] object Utils extends Logging {
/** Return the class name of the given object, removing all dollar signs */
def getFormattedClassName(obj: AnyRef): String = {
- obj.getClass.getSimpleName.replace("$", "")
- }
-
- /** Return an option that translates JNothing to None */
- def jsonOption(json: JValue): Option[JValue] = {
- json match {
- case JNothing => None
- case value: JValue => Some(value)
- }
+ getSimpleName(obj.getClass).replace("$", "")
}
- /** Return an empty JSON object */
- def emptyJson: JsonAST.JObject = JObject(List[JField]())
-
/**
* Return a Hadoop FileSystem with the scheme encoded in the given path.
*/
@@ -1900,15 +1847,6 @@ private[spark] object Utils extends Logging {
getHadoopFileSystem(new URI(path), conf)
}
- /**
- * Return the absolute path of a file in the given directory.
- */
- def getFilePath(dir: File, fileName: String): Path = {
- assert(dir.isDirectory)
- val path = new File(dir, fileName).getAbsolutePath
- new Path(path)
- }
-
/**
* Whether the underlying operating system is Windows.
*/
@@ -1931,13 +1869,6 @@ private[spark] object Utils extends Logging {
sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing")
}
- /**
- * Strip the directory from a path name
- */
- def stripDirectory(path: String): String = {
- new File(path).getName
- }
-
/**
* Terminates a process waiting for at most the specified duration.
*
@@ -2168,7 +2099,22 @@ private[spark] object Utils extends Logging {
// We need to filter out null values here because dumpAllThreads() may return null array
// elements for threads that are dead / don't exist.
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
- threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace)
+ threadInfos.sortWith { case (threadTrace1, threadTrace2) =>
+ val v1 = if (threadTrace1.getThreadName.contains("Executor task launch")) 1 else 0
+ val v2 = if (threadTrace2.getThreadName.contains("Executor task launch")) 1 else 0
+ if (v1 == v2) {
+ val name1 = threadTrace1.getThreadName().toLowerCase(Locale.ROOT)
+ val name2 = threadTrace2.getThreadName().toLowerCase(Locale.ROOT)
+ val nameCmpRes = name1.compareTo(name2)
+ if (nameCmpRes == 0) {
+ threadTrace1.getThreadId < threadTrace2.getThreadId
+ } else {
+ nameCmpRes < 0
+ }
+ } else {
+ v1 > v2
+ }
+ }.map(threadInfoToThreadStackTrace)
}
def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = {
@@ -2184,14 +2130,14 @@ private[spark] object Utils extends Logging {
private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = {
val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap
- val stackTrace = threadInfo.getStackTrace.map { frame =>
+ val stackTrace = StackTrace(threadInfo.getStackTrace.map { frame =>
monitors.get(frame) match {
case Some(monitor) =>
monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}"
case None =>
frame.toString
}
- }.mkString("\n")
+ })
// use a set to dedup re-entrant locks that are held at multiple places
val heldLocks =
@@ -2333,50 +2279,6 @@ private[spark] object Utils extends Logging {
org.apache.log4j.Logger.getRootLogger().setLevel(l)
}
- /**
- * config a log4j properties used for testsuite
- */
- def configTestLog4j(level: String): Unit = {
- val pro = new Properties()
- pro.put("log4j.rootLogger", s"$level, console")
- pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender")
- pro.put("log4j.appender.console.target", "System.err")
- pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout")
- pro.put("log4j.appender.console.layout.ConversionPattern",
- "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n")
- PropertyConfigurator.configure(pro)
- }
-
- /**
- * If the given URL connection is HttpsURLConnection, it sets the SSL socket factory and
- * the host verifier from the given security manager.
- */
- def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = {
- urlConnection match {
- case https: HttpsURLConnection =>
- sm.sslSocketFactory.foreach(https.setSSLSocketFactory)
- sm.hostnameVerifier.foreach(https.setHostnameVerifier)
- https
- case connection => connection
- }
- }
-
- def invoke(
- clazz: Class[_],
- obj: AnyRef,
- methodName: String,
- args: (Class[_], AnyRef)*): AnyRef = {
- val (types, values) = args.unzip
- val method = clazz.getDeclaredMethod(methodName, types: _*)
- method.setAccessible(true)
- method.invoke(obj, values.toSeq: _*)
- }
-
- // Limit of bytes for total size of results (default is 1GB)
- def getMaxResultSize(conf: SparkConf): Long = {
- memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20
- }
-
/**
* Return the current system LD_LIBRARY_PATH name
*/
@@ -2412,16 +2314,20 @@ private[spark] object Utils extends Logging {
}
/**
- * Return the value of a config either through the SparkConf or the Hadoop configuration
- * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf
- * if the key is not set in the Hadoop configuration.
+ * Return the value of a config either through the SparkConf or the Hadoop configuration.
+ * We check whether the key is set in the SparkConf before looking at any Hadoop configuration.
+ * If the key is set in the SparkConf, the value is taken from the SparkConf regardless of
+ * whether the application is running on YARN.
+ * Only when the key is not set in the SparkConf and the application is running on YARN
+ * is the value taken from the Hadoop configuration.
*/
def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = {
- val sparkValue = conf.get(key, default)
- if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") {
- new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, sparkValue)
+ if (conf.contains(key)) {
+ conf.get(key, default)
+ } else if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") {
+ new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, default)
} else {
- sparkValue
+ default
}
}
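
The new precedence, restated as a sketch; the key and the values are examples only:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.master", "yarn")
      .set("spark.shuffle.service.port", "7447")

    // Set in the SparkConf, so "7447" is returned even though the app targets YARN; the
    // Hadoop/YARN configuration is only consulted when the key is absent from the SparkConf.
    Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", default = "7337")
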
@@ -2609,16 +2515,6 @@ private[spark] object Utils extends Logging {
SignalUtils.registerLogger(log)
}
- /**
- * Unions two comma-separated lists of files and filters out empty strings.
- */
- def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = {
- var allFiles = Set.empty[String]
- leftList.foreach { value => allFiles ++= value.split(",") }
- rightList.foreach { value => allFiles ++= value.split(",") }
- allFiles.filter { _.nonEmpty }
- }
-
/**
* Return the jar files pointed by the "spark.jars" property. Spark internally will distribute
* these jars through file server. In the YARN mode, it will return an empty list, since YARN
@@ -2805,6 +2701,86 @@ private[spark] object Utils extends Logging {
s"k8s://$resolvedURL"
}
+
+ /**
+ * Replaces all the {{EXECUTOR_ID}} occurrences with the Executor Id
+ * and {{APP_ID}} occurrences with the App Id.
+ */
+ def substituteAppNExecIds(opt: String, appId: String, execId: String): String = {
+ opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId)
+ }
+
+ /**
+ * Replaces all the {{APP_ID}} occurrences with the App Id.
+ */
+ def substituteAppId(opt: String, appId: String): String = {
+ opt.replace("{{APP_ID}}", appId)
+ }
+
+ def createSecret(conf: SparkConf): String = {
+ val bits = conf.get(AUTH_SECRET_BIT_LENGTH)
+ val rnd = new SecureRandom()
+ val secretBytes = new Array[Byte](bits / JByte.SIZE)
+ rnd.nextBytes(secretBytes)
+ HashCodes.fromBytes(secretBytes).toString()
+ }
+
+ /**
+ * Safer than Class's getSimpleName, which may throw a "Malformed class name" error in Scala.
+ * This method mimics scalatest's getSimpleNameOfAnObjectsClass.
+ */
+ def getSimpleName(cls: Class[_]): String = {
+ try {
+ return cls.getSimpleName
+ } catch {
+ case err: InternalError => return stripDollars(stripPackages(cls.getName))
+ }
+ }
+
+ /**
+ * Remove the packages from a fully qualified class name
+ */
+ private def stripPackages(fullyQualifiedName: String): String = {
+ fullyQualifiedName.split("\\.").takeRight(1)(0)
+ }
+
+ /**
+ * Remove trailing dollar signs from qualified class name,
+ * and return the trailing part after the last dollar sign in the middle
+ */
+ private def stripDollars(s: String): String = {
+ val lastDollarIndex = s.lastIndexOf('$')
+ if (lastDollarIndex < s.length - 1) {
+ // The last char is not a dollar sign
+ if (lastDollarIndex == -1 || !s.contains("$iw")) {
+ // The name does not have a dollar sign or is not an interpreter
+ // generated class, so we should return the full string
+ s
+ } else {
+ // The class name is interpreter generated,
+ // return the part after the last dollar sign
+ // This is the same behavior as getClass.getSimpleName
+ s.substring(lastDollarIndex + 1)
+ }
+ }
+ else {
+ // The last char is a dollar sign
+ // Find last non-dollar char
+ val lastNonDollarChar = s.reverse.find(_ != '$')
+ lastNonDollarChar match {
+ case None => s
+ case Some(c) =>
+ val lastNonDollarIndex = s.lastIndexOf(c)
+ if (lastNonDollarIndex == -1) {
+ s
+ } else {
+ // Strip the trailing dollar signs
+ // Invoke stripDollars again to get the simple name
+ stripDollars(s.substring(0, lastNonDollarIndex + 1))
+ }
+ }
+ }
+ }
}
private[util] object CallerContext extends Logging {
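
A couple of illustrative expectations for the new helper; the interpreter-style class name below is invented for the example:

    // Behaves like getSimpleName for ordinary classes.
    assert(Utils.getSimpleName(classOf[java.util.ArrayList[_]]) == "ArrayList")

    // For REPL-generated names such as "$line3.$read$$iw$$iw$Foo", plain getSimpleName can throw
    // java.lang.InternalError("Malformed class name"); the fallback strips the packages and
    // dollar-sign prefixes instead, yielding the trailing simple name ("Foo" in this made-up case).
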
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 8183f825592c0..81457b53cd814 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util.collection
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager}
/**
@@ -41,7 +42,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
protected def forceSpill(): Boolean
// Number of elements read from input since last spill
- protected def elementsRead: Long = _elementsRead
+ protected def elementsRead: Int = _elementsRead
// Called by subclasses every time a record is read
// It's used for checking spilling frequency
@@ -54,15 +55,15 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
// Force this collection to spill when there are this many elements in memory
// For testing only
- private[this] val numElementsForceSpillThreshold: Long =
- SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue)
+ private[this] val numElementsForceSpillThreshold: Int =
+ SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD)
// Threshold for this collection's size in bytes before we start tracking its memory usage
// To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
@volatile private[this] var myMemoryThreshold = initialMemoryThreshold
// Number of elements read from input since last spill
- private[this] var _elementsRead = 0L
+ private[this] var _elementsRead = 0
// Number of bytes spilled in total
@volatile private[this] var _memoryBytesSpilled = 0L
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index 7367af7888bd8..700ce56466c35 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -63,10 +63,19 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
*/
def writeFully(channel: WritableByteChannel): Unit = {
for (bytes <- getChunks()) {
- while (bytes.remaining() > 0) {
+ val originalLimit = bytes.limit()
+ while (bytes.hasRemaining) {
+ // If `bytes` is an on-heap ByteBuffer, the Java NIO API will copy it to a temporary direct
+ // ByteBuffer when writing it out. This temporary direct ByteBuffer is cached per thread.
+ // Its size has no limit and can keep growing if it sees a larger input ByteBuffer. This may
+ // cause significant native memory leak, if a large direct ByteBuffer is allocated and
+ // cached, as it's never released until thread exits. Here we write the `bytes` with
+ // fixed-size slices to limit the size of the cached direct ByteBuffer.
+ // Please refer to http://www.evanjones.ca/java-bytebuffer-leak.html for more details.
val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize)
bytes.limit(bytes.position() + ioSize)
channel.write(bytes)
+ bytes.limit(originalLimit)
}
}
}
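
The same slicing pattern in isolation, with the loop's invariants spelled out; `channel`, `buf`, and `chunkSize` are assumed to come from the caller:

    val originalLimit = buf.limit()
    while (buf.hasRemaining) {
      val ioSize = math.min(buf.remaining(), chunkSize)
      buf.limit(buf.position() + ioSize) // expose at most chunkSize bytes to the channel
      channel.write(buf)                 // the thread-cached temporary direct buffer stays bounded
      buf.limit(originalLimit)           // restore so the next iteration can see the rest
    }
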
diff --git a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java
index 3440e1aea2f46..22db3592ecc96 100644
--- a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java
+++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java
@@ -37,7 +37,7 @@ public abstract class GenericFileInputStreamSuite {
protected File inputFile;
- protected InputStream inputStream;
+ protected InputStream[] inputStreams;
@Before
public void setUp() throws IOException {
@@ -54,77 +54,91 @@ public void tearDown() {
@Test
public void testReadOneByte() throws IOException {
- for (int i = 0; i < randomBytes.length; i++) {
- assertEquals(randomBytes[i], (byte) inputStream.read());
+ for (InputStream inputStream: inputStreams) {
+ for (int i = 0; i < randomBytes.length; i++) {
+ assertEquals(randomBytes[i], (byte) inputStream.read());
+ }
}
}
@Test
public void testReadMultipleBytes() throws IOException {
- byte[] readBytes = new byte[8 * 1024];
- int i = 0;
- while (i < randomBytes.length) {
- int read = inputStream.read(readBytes, 0, 8 * 1024);
- for (int j = 0; j < read; j++) {
- assertEquals(randomBytes[i], readBytes[j]);
- i++;
+ for (InputStream inputStream: inputStreams) {
+ byte[] readBytes = new byte[8 * 1024];
+ int i = 0;
+ while (i < randomBytes.length) {
+ int read = inputStream.read(readBytes, 0, 8 * 1024);
+ for (int j = 0; j < read; j++) {
+ assertEquals(randomBytes[i], readBytes[j]);
+ i++;
+ }
}
}
}
@Test
public void testBytesSkipped() throws IOException {
- assertEquals(1024, inputStream.skip(1024));
- for (int i = 1024; i < randomBytes.length; i++) {
- assertEquals(randomBytes[i], (byte) inputStream.read());
+ for (InputStream inputStream: inputStreams) {
+ assertEquals(1024, inputStream.skip(1024));
+ for (int i = 1024; i < randomBytes.length; i++) {
+ assertEquals(randomBytes[i], (byte) inputStream.read());
+ }
}
}
@Test
public void testBytesSkippedAfterRead() throws IOException {
- for (int i = 0; i < 1024; i++) {
- assertEquals(randomBytes[i], (byte) inputStream.read());
- }
- assertEquals(1024, inputStream.skip(1024));
- for (int i = 2048; i < randomBytes.length; i++) {
- assertEquals(randomBytes[i], (byte) inputStream.read());
+ for (InputStream inputStream: inputStreams) {
+ for (int i = 0; i < 1024; i++) {
+ assertEquals(randomBytes[i], (byte) inputStream.read());
+ }
+ assertEquals(1024, inputStream.skip(1024));
+ for (int i = 2048; i < randomBytes.length; i++) {
+ assertEquals(randomBytes[i], (byte) inputStream.read());
+ }
}
}
@Test
public void testNegativeBytesSkippedAfterRead() throws IOException {
- for (int i = 0; i < 1024; i++) {
- assertEquals(randomBytes[i], (byte) inputStream.read());
- }
- // Skipping negative bytes should essential be a no-op
- assertEquals(0, inputStream.skip(-1));
- assertEquals(0, inputStream.skip(-1024));
- assertEquals(0, inputStream.skip(Long.MIN_VALUE));
- assertEquals(1024, inputStream.skip(1024));
- for (int i = 2048; i < randomBytes.length; i++) {
- assertEquals(randomBytes[i], (byte) inputStream.read());
+ for (InputStream inputStream: inputStreams) {
+ for (int i = 0; i < 1024; i++) {
+ assertEquals(randomBytes[i], (byte) inputStream.read());
+ }
+ // Skipping negative bytes should essentially be a no-op
+ assertEquals(0, inputStream.skip(-1));
+ assertEquals(0, inputStream.skip(-1024));
+ assertEquals(0, inputStream.skip(Long.MIN_VALUE));
+ assertEquals(1024, inputStream.skip(1024));
+ for (int i = 2048; i < randomBytes.length; i++) {
+ assertEquals(randomBytes[i], (byte) inputStream.read());
+ }
}
}
@Test
public void testSkipFromFileChannel() throws IOException {
- // Since the buffer is smaller than the skipped bytes, this will guarantee
- // we skip from underlying file channel.
- assertEquals(1024, inputStream.skip(1024));
- for (int i = 1024; i < 2048; i++) {
- assertEquals(randomBytes[i], (byte) inputStream.read());
- }
- assertEquals(256, inputStream.skip(256));
- assertEquals(256, inputStream.skip(256));
- assertEquals(512, inputStream.skip(512));
- for (int i = 3072; i < randomBytes.length; i++) {
- assertEquals(randomBytes[i], (byte) inputStream.read());
+ for (InputStream inputStream: inputStreams) {
+ // Since the buffer is smaller than the skipped bytes, this will guarantee
+ // we skip from underlying file channel.
+ assertEquals(1024, inputStream.skip(1024));
+ for (int i = 1024; i < 2048; i++) {
+ assertEquals(randomBytes[i], (byte) inputStream.read());
+ }
+ assertEquals(256, inputStream.skip(256));
+ assertEquals(256, inputStream.skip(256));
+ assertEquals(512, inputStream.skip(512));
+ for (int i = 3072; i < randomBytes.length; i++) {
+ assertEquals(randomBytes[i], (byte) inputStream.read());
+ }
}
}
@Test
public void testBytesSkippedAfterEOF() throws IOException {
- assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
- assertEquals(-1, inputStream.read());
+ for (InputStream inputStream: inputStreams) {
+ assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
+ assertEquals(-1, inputStream.read());
+ }
}
}
diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java
index 211b33a1a9fb0..a320f8662f707 100644
--- a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java
+++ b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java
@@ -18,6 +18,7 @@
import org.junit.Before;
+import java.io.InputStream;
import java.io.IOException;
/**
@@ -28,6 +29,9 @@ public class NioBufferedInputStreamSuite extends GenericFileInputStreamSuite {
@Before
public void setUp() throws IOException {
super.setUp();
- inputStream = new NioBufferedFileInputStream(inputFile);
+ inputStreams = new InputStream[] {
+ new NioBufferedFileInputStream(inputFile), // default
+ new NioBufferedFileInputStream(inputFile, 123) // small, unaligned buffer
+ };
}
}
diff --git a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java
index 918ddc4517ec4..bfa1e0b908824 100644
--- a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java
+++ b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java
@@ -19,16 +19,27 @@
import org.junit.Before;
import java.io.IOException;
+import java.io.InputStream;
/**
- * Tests functionality of {@link NioBufferedFileInputStream}
+ * Tests functionality of {@link ReadAheadInputStream}
*/
public class ReadAheadInputStreamSuite extends GenericFileInputStreamSuite {
@Before
public void setUp() throws IOException {
super.setUp();
- inputStream = new ReadAheadInputStream(
- new NioBufferedFileInputStream(inputFile), 8 * 1024, 4 * 1024);
+ inputStreams = new InputStream[] {
+ // Tests equal and aligned buffers of the wrapped and outer streams.
+ new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 8 * 1024), 8 * 1024),
+ // Tests aligned buffers, wrapped bigger than outer.
+ new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 3 * 1024), 2 * 1024),
+ // Tests aligned buffers, wrapped smaller than outer.
+ new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 2 * 1024), 3 * 1024),
+ // Tests unaligned buffers, wrapped bigger than outer.
+ new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 321), 123),
+ // Tests unaligned buffers, wrapped smaller than outer.
+ new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 123), 321)
+ };
}
}
diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
index 2225591a4ff75..6a1a38c1a54f4 100644
--- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
+++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
@@ -109,7 +109,7 @@ public void testChildProcLauncher() throws Exception {
.addSparkArg(opts.CONF,
String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS))
.setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS,
- "-Dfoo=bar -Dtest.appender=childproc")
+ "-Dfoo=bar -Dtest.appender=console")
.setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path"))
.addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow")
.setMainClass(SparkLauncherTestApp.class.getName())
@@ -192,6 +192,41 @@ private void inProcessLauncherTestImpl() throws Exception {
}
}
+ @Test
+ public void testInProcessLauncherDoesNotKillJvm() throws Exception {
+ SparkSubmitOptionParser opts = new SparkSubmitOptionParser();
+ List<String[]> wrongArgs = Arrays.asList(
+ new String[] { "--unknown" },
+ new String[] { opts.DEPLOY_MODE, "invalid" });
+
+ for (String[] args : wrongArgs) {
+ InProcessLauncher launcher = new InProcessLauncher()
+ .setAppResource(SparkLauncher.NO_RESOURCE);
+ switch (args.length) {
+ case 2:
+ launcher.addSparkArg(args[0], args[1]);
+ break;
+
+ case 1:
+ launcher.addSparkArg(args[0]);
+ break;
+
+ default:
+ fail("FIXME: invalid test.");
+ }
+
+ SparkAppHandle handle = launcher.startApplication();
+ waitFor(handle);
+ assertEquals(SparkAppHandle.State.FAILED, handle.getState());
+ }
+
+ // Run --version, which is useless as a use case, but should succeed and not exit the JVM.
+ // The expected state is "LOST" since "--version" doesn't report state back to the handle.
+ SparkAppHandle handle = new InProcessLauncher().addSparkArg(opts.VERSION).startApplication();
+ waitFor(handle);
+ assertEquals(SparkAppHandle.State.LOST, handle.getState());
+ }
+
public static class SparkLauncherTestApp {
public static void main(String[] args) throws Exception {
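
A rough usage sketch of InProcessLauncher outside the test helper; the main class name below is a placeholder, and waiting is done by polling the handle's state until it is final:

  import org.apache.spark.launcher.{InProcessLauncher, SparkAppHandle, SparkLauncher}

  val handle: SparkAppHandle = new InProcessLauncher()
    .setAppResource(SparkLauncher.NO_RESOURCE)
    .setMainClass("com.example.MyApp")  // placeholder application class
    .startApplication()

  // Poll until the launcher reports a terminal state; a bad argument yields FAILED
  // without exiting the JVM, which is what the test above verifies.
  while (!handle.getState.isFinal) {
    Thread.sleep(100)
  }
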
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index a0664b30d6cc2..d7d2d0b012bd3 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -76,7 +76,7 @@ public void freeingPageSetsPageNumberToSpecialConstant() {
final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
final MemoryBlock dataPage = manager.allocatePage(256, c);
c.freePage(dataPage);
- Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber);
+ Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.getPageNumber());
}
@Test(expected = AssertionError.class)
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 24a55df84a240..0d5c5ea7903e9 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -95,7 +95,7 @@ public void tearDown() {
@SuppressWarnings("unchecked")
public void setUp() throws IOException {
MockitoAnnotations.initMocks(this);
- tempDir = Utils.createTempDir("test", "test");
+ tempDir = Utils.createTempDir(null, "test");
mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
partitionSizesInMergedFile = null;
spillFilesCreated.clear();
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index c145532328514..85ffdca436e14 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -129,7 +129,6 @@ public int compare(
final UnsafeSorterIterator iter = sorter.getSortedIterator();
int iterLength = 0;
long prevPrefix = -1;
- Arrays.sort(dataToSort);
while (iter.hasNext()) {
iter.loadNext();
final String str =
diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
index 942e6d8f04363..7bb8fe8fd8f98 100644
--- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
@@ -19,5 +19,6 @@
"isBlacklisted" : false,
"maxMemory" : 278302556,
"addTime" : "2015-02-03T16:43:00.906GMT",
- "executorLogs" : { }
+ "executorLogs" : { },
+ "blacklistedInStages" : [ ]
} ]
diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json
index ed33c90dd39ba..dd5b1dcb7372b 100644
--- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json
@@ -25,7 +25,8 @@
"usedOffHeapStorageMemory" : 0,
"totalOnHeapStorageMemory" : 384093388,
"totalOffHeapStorageMemory" : 524288000
- }
+ },
+ "blacklistedInStages" : [ ]
}, {
"id" : "3",
"hostPort" : "172.22.0.167:51485",
@@ -56,7 +57,8 @@
"usedOffHeapStorageMemory" : 0,
"totalOnHeapStorageMemory" : 384093388,
"totalOffHeapStorageMemory" : 524288000
- }
+ },
+ "blacklistedInStages" : [ ]
} ,{
"id" : "2",
"hostPort" : "172.22.0.167:51487",
@@ -87,7 +89,8 @@
"usedOffHeapStorageMemory" : 0,
"totalOnHeapStorageMemory" : 384093388,
"totalOffHeapStorageMemory" : 524288000
- }
+ },
+ "blacklistedInStages" : [ ]
}, {
"id" : "1",
"hostPort" : "172.22.0.167:51490",
@@ -118,7 +121,8 @@
"usedOffHeapStorageMemory": 0,
"totalOnHeapStorageMemory": 384093388,
"totalOffHeapStorageMemory": 524288000
- }
+ },
+ "blacklistedInStages" : [ ]
}, {
"id" : "0",
"hostPort" : "172.22.0.167:51491",
@@ -149,5 +153,6 @@
"usedOffHeapStorageMemory" : 0,
"totalOnHeapStorageMemory" : 384093388,
"totalOffHeapStorageMemory" : 524288000
- }
+ },
+ "blacklistedInStages" : [ ]
} ]
diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json
index 73519f1d9e2e4..3e55d3d9d7eb9 100644
--- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json
@@ -25,7 +25,8 @@
"usedOffHeapStorageMemory" : 0,
"totalOnHeapStorageMemory" : 384093388,
"totalOffHeapStorageMemory" : 524288000
- }
+ },
+ "blacklistedInStages" : [ ]
}, {
"id" : "3",
"hostPort" : "172.22.0.167:51485",
@@ -56,7 +57,8 @@
"usedOffHeapStorageMemory" : 0,
"totalOnHeapStorageMemory" : 384093388,
"totalOffHeapStorageMemory" : 524288000
- }
+ },
+ "blacklistedInStages" : [ ]
}, {
"id" : "2",
"hostPort" : "172.22.0.167:51487",
@@ -87,7 +89,8 @@
"usedOffHeapStorageMemory" : 0,
"totalOnHeapStorageMemory" : 384093388,
"totalOffHeapStorageMemory" : 524288000
- }
+ },
+ "blacklistedInStages" : [ ]
}, {
"id" : "1",
"hostPort" : "172.22.0.167:51490",
@@ -118,7 +121,8 @@
"usedOffHeapStorageMemory": 0,
"totalOnHeapStorageMemory": 384093388,
"totalOffHeapStorageMemory": 524288000
- }
+ },
+ "blacklistedInStages" : [ ]
}, {
"id" : "0",
"hostPort" : "172.22.0.167:51491",
@@ -149,5 +153,6 @@
"usedOffHeapStorageMemory": 0,
"totalOnHeapStorageMemory": 384093388,
"totalOffHeapStorageMemory": 524288000
- }
+ },
+ "blacklistedInStages" : [ ]
} ]
diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json
index 6931fead3d2ff..e87f3e78f2dc8 100644
--- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json
@@ -19,7 +19,8 @@
"isBlacklisted" : false,
"maxMemory" : 384093388,
"addTime" : "2016-11-15T23:20:38.836GMT",
- "executorLogs" : { }
+ "executorLogs" : { },
+ "blacklistedInStages" : [ ]
}, {
"id" : "3",
"hostPort" : "172.22.0.111:64543",
@@ -44,7 +45,8 @@
"executorLogs" : {
"stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout",
"stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr"
- }
+ },
+ "blacklistedInStages" : [ ]
}, {
"id" : "2",
"hostPort" : "172.22.0.111:64539",
@@ -69,7 +71,8 @@
"executorLogs" : {
"stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout",
"stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr"
- }
+ },
+ "blacklistedInStages" : [ ]
}, {
"id" : "1",
"hostPort" : "172.22.0.111:64541",
@@ -94,7 +97,8 @@
"executorLogs" : {
"stdout" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stdout",
"stderr" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stderr"
- }
+ },
+ "blacklistedInStages" : [ ]
}, {
"id" : "0",
"hostPort" : "172.22.0.111:64540",
@@ -119,5 +123,6 @@
"executorLogs" : {
"stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout",
"stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr"
- }
+ },
+ "blacklistedInStages" : [ ]
} ]
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 3990ee1ec326d..5d0ffd92647bc 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -209,10 +209,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
System.gc()
assert(ref.get.isEmpty)
- // Getting a garbage collected accum should throw error
- intercept[IllegalStateException] {
- AccumulatorContext.get(accId)
- }
+ // Getting a garbage collected accum should return None.
+ assert(AccumulatorContext.get(accId).isEmpty)
// Getting a normal accumulator. Note: this has to be separate because referencing an
// accumulator above in an `assert` would keep it from being garbage collected.
diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala
index 91355f7362900..a5bdc95790722 100644
--- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala
+++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala
@@ -103,8 +103,11 @@ class DebugFilesystem extends LocalFileSystem {
override def markSupported(): Boolean = wrapped.markSupported()
override def close(): Unit = {
- wrapped.close()
- removeOpenStream(wrapped)
+ try {
+ wrapped.close()
+ } finally {
+ removeOpenStream(wrapped)
+ }
}
override def read(): Int = wrapped.read()
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index e09d5f59817b9..28ea0c6f0bdba 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -160,11 +160,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
val data = sc.parallelize(1 to 1000, 10)
val cachedData = data.persist(storageLevel)
assert(cachedData.count === 1000)
- assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum ===
- storageLevel.replication * data.getNumPartitions)
- assert(cachedData.count === 1000)
- assert(cachedData.count === 1000)
-
+ assert(sc.getRDDStorageInfo.filter(_.id == cachedData.id).map(_.numCachedPartitions).sum ===
+ data.getNumPartitions)
// Get all the locations of the first partition and try to fetch the partitions
// from those locations.
val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index 962945e5b6bb1..896cd2e80aaef 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -51,7 +51,7 @@ class DriverSuite extends SparkFunSuite with TimeLimits {
*/
object DriverWithoutCleanup {
def main(args: Array[String]) {
- Utils.configTestLog4j("INFO")
+ TestUtils.configTestLog4j("INFO")
val conf = new SparkConf
val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf)
sc.parallelize(1 to 100, 4).count()
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index a0cae5a9e011c..3cfb0a9feb32b 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark
import scala.collection.mutable
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.{mock, never, verify, when}
import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
import org.apache.spark.executor.TaskMetrics
@@ -26,6 +28,7 @@ import org.apache.spark.scheduler._
import org.apache.spark.scheduler.ExternalClusterManager
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler.local.LocalSchedulerBackend
+import org.apache.spark.storage.BlockManagerMaster
import org.apache.spark.util.ManualClock
/**
@@ -142,6 +145,39 @@ class ExecutorAllocationManagerSuite
assert(numExecutorsToAdd(manager) === 1)
}
+ def testAllocationRatio(cores: Int, divisor: Double, expected: Int): Unit = {
+ val conf = new SparkConf()
+ .setMaster("myDummyLocalExternalClusterManager")
+ .setAppName("test-executor-allocation-manager")
+ .set("spark.dynamicAllocation.enabled", "true")
+ .set("spark.dynamicAllocation.testing", "true")
+ .set("spark.dynamicAllocation.maxExecutors", "15")
+ .set("spark.dynamicAllocation.minExecutors", "3")
+ .set("spark.dynamicAllocation.executorAllocationRatio", divisor.toString)
+ .set("spark.executor.cores", cores.toString)
+ val sc = new SparkContext(conf)
+ contexts += sc
+ var manager = sc.executorAllocationManager.get
+ post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 20)))
+ for (i <- 0 to 5) {
+ addExecutors(manager)
+ }
+ assert(numExecutorsTarget(manager) === expected)
+ sc.stop()
+ }
+
+ test("executionAllocationRatio is correctly handled") {
+ testAllocationRatio(1, 0.5, 10)
+ testAllocationRatio(1, 1.0/3.0, 7)
+ testAllocationRatio(2, 1.0/3.0, 4)
+ testAllocationRatio(1, 0.385, 8)
+
+ // max/min executors capping
+ testAllocationRatio(1, 1.0, 15) // should be 20 but capped by max
+ testAllocationRatio(4, 1.0/3.0, 3) // should be 2 but elevated by min
+ }
+
+
test("add executors capped by num pending tasks") {
sc = createSparkContext(0, 10, 0)
val manager = sc.executorAllocationManager.get
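
For reference, a small Scala sketch of the arithmetic the assertions above encode, assuming the target is ceil(pendingTasks * ratio / coresPerExecutor) clamped to the configured min/max executors:

  def targetExecutors(pendingTasks: Int, ratio: Double, coresPerExecutor: Int,
      minExecutors: Int = 3, maxExecutors: Int = 15): Int = {
    val wanted = math.ceil(pendingTasks * ratio / coresPerExecutor).toInt
    math.min(maxExecutors, math.max(minExecutors, wanted))
  }

  // REPL-style checks of the values asserted above:
  targetExecutors(20, 0.5, 1)        // 10
  targetExecutors(20, 1.0 / 3.0, 1)  // 7
  targetExecutors(20, 1.0 / 3.0, 2)  // 4
  targetExecutors(20, 1.0, 1)        // 20, capped at maxExecutors = 15
  targetExecutors(20, 1.0 / 3.0, 4)  // 2, raised to minExecutors = 3
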
@@ -1050,6 +1086,66 @@ class ExecutorAllocationManagerSuite
assert(removeTimes(manager) === Map.empty)
}
+ test("SPARK-23365 Don't update target num executors when killing idle executors") {
+ val minExecutors = 1
+ val initialExecutors = 1
+ val maxExecutors = 2
+ val conf = new SparkConf()
+ .set("spark.dynamicAllocation.enabled", "true")
+ .set("spark.shuffle.service.enabled", "true")
+ .set("spark.dynamicAllocation.minExecutors", minExecutors.toString)
+ .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString)
+ .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString)
+ .set("spark.dynamicAllocation.schedulerBacklogTimeout", "1000ms")
+ .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "1000ms")
+ .set("spark.dynamicAllocation.executorIdleTimeout", s"3000ms")
+ val mockAllocationClient = mock(classOf[ExecutorAllocationClient])
+ val mockBMM = mock(classOf[BlockManagerMaster])
+ val manager = new ExecutorAllocationManager(
+ mockAllocationClient, mock(classOf[LiveListenerBus]), conf, mockBMM)
+ val clock = new ManualClock()
+ manager.setClock(clock)
+
+ when(mockAllocationClient.requestTotalExecutors(meq(2), any(), any())).thenReturn(true)
+ // test setup -- job with 2 tasks, scale up to two executors
+ assert(numExecutorsTarget(manager) === 1)
+ manager.listener.onExecutorAdded(SparkListenerExecutorAdded(
+ clock.getTimeMillis(), "executor-1", new ExecutorInfo("host1", 1, Map.empty)))
+ manager.listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 2)))
+ clock.advance(1000)
+ manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis())
+ assert(numExecutorsTarget(manager) === 2)
+ val taskInfo0 = createTaskInfo(0, 0, "executor-1")
+ manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo0))
+ manager.listener.onExecutorAdded(SparkListenerExecutorAdded(
+ clock.getTimeMillis(), "executor-2", new ExecutorInfo("host1", 1, Map.empty)))
+ val taskInfo1 = createTaskInfo(1, 1, "executor-2")
+ manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo1))
+ assert(numExecutorsTarget(manager) === 2)
+
+ // have one task finish -- we should adjust the target number of executors down
+ // but we should *not* kill any executors yet
+ manager.listener.onTaskEnd(SparkListenerTaskEnd(0, 0, null, Success, taskInfo0, null))
+ assert(maxNumExecutorsNeeded(manager) === 1)
+ assert(numExecutorsTarget(manager) === 2)
+ clock.advance(1000)
+ manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis())
+ assert(numExecutorsTarget(manager) === 1)
+ verify(mockAllocationClient, never).killExecutors(any(), any(), any(), any())
+
+ // now we cross the idle timeout for executor-1, so we kill it. the really important
+ // thing here is that we do *not* ask the executor allocation client to adjust the target
+ // number of executors down
+ when(mockAllocationClient.killExecutors(Seq("executor-1"), false, false, false))
+ .thenReturn(Seq("executor-1"))
+ clock.advance(3000)
+ schedule(manager)
+ assert(maxNumExecutorsNeeded(manager) === 1)
+ assert(numExecutorsTarget(manager) === 1)
+ // here's the important verify -- we did kill the executors, but did not adjust the target count
+ verify(mockAllocationClient).killExecutors(Seq("executor-1"), false, false, false)
+ }
+
private def createSparkContext(
minExecutors: Int = 1,
maxExecutors: Int = 5,
@@ -1268,7 +1364,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend
override def killExecutors(
executorIds: Seq[String],
- replace: Boolean,
+ adjustTargetNumExecutors: Boolean,
+ countFailures: Boolean,
force: Boolean): Seq[String] = executorIds
override def start(): Unit = sb.start()
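
A brief usage sketch of the widened killExecutors signature above (client is an assumed ExecutorAllocationClient): when reclaiming idle executors, the manager asks the client not to lower the target so a later burst of tasks can ramp back up, which is the behavior the SPARK-23365 test verifies.

  client.killExecutors(
    Seq("executor-1"),
    adjustTargetNumExecutors = false,
    countFailures = false,
    force = false)
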
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index 55a9122cf9026..a441b9c8ab97a 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -23,6 +23,7 @@ import java.util.zip.GZIPOutputStream
import scala.io.Source
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io._
import org.apache.hadoop.io.compress.DefaultCodec
@@ -32,7 +33,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp
import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat}
import org.apache.spark.internal.config._
-import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD}
+import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -596,4 +597,70 @@ class FileSuite extends SparkFunSuite with LocalSparkContext {
actualPartitionNum = 5,
expectedPartitionNum = 2)
}
+
+ test("spark.files.ignoreMissingFiles should work both HadoopRDD and NewHadoopRDD") {
+ // "file not found" can happen both when getPartitions or compute in HadoopRDD/NewHadoopRDD,
+ // We test both cases here.
+
+ val deletedPath = new Path(tempDir.getAbsolutePath, "test-data-1")
+ val fs = deletedPath.getFileSystem(new Configuration())
+ fs.delete(deletedPath, true)
+ intercept[FileNotFoundException](fs.open(deletedPath))
+
+ def collectRDDAndDeleteFileBeforeCompute(newApi: Boolean): Array[_] = {
+ val dataPath = new Path(tempDir.getAbsolutePath, "test-data-2")
+ val writer = new OutputStreamWriter(new FileOutputStream(new File(dataPath.toString)))
+ writer.write("hello\n")
+ writer.write("world\n")
+ writer.close()
+ val rdd = if (newApi) {
+ sc.newAPIHadoopFile(dataPath.toString, classOf[NewTextInputFormat],
+ classOf[LongWritable], classOf[Text])
+ } else {
+ sc.textFile(dataPath.toString)
+ }
+ rdd.partitions
+ fs.delete(dataPath, true)
+ // The exception happens when initializing the record reader in HadoopRDD/NewHadoopRDD.compute
+ // because the partitions' info is already cached.
+ rdd.collect()
+ }
+
+ // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=false by default.
+ sc = new SparkContext("local", "test")
+ intercept[org.apache.hadoop.mapred.InvalidInputException] {
+ // Exception happens when HadoopRDD.getPartitions
+ sc.textFile(deletedPath.toString).collect()
+ }
+
+ var e = intercept[SparkException] {
+ collectRDDAndDeleteFileBeforeCompute(false)
+ }
+ assert(e.getCause.isInstanceOf[java.io.FileNotFoundException])
+
+ intercept[org.apache.hadoop.mapreduce.lib.input.InvalidInputException] {
+ // Exception happens when NewHadoopRDD.getPartitions
+ sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat],
+ classOf[LongWritable], classOf[Text]).collect
+ }
+
+ e = intercept[SparkException] {
+ collectRDDAndDeleteFileBeforeCompute(true)
+ }
+ assert(e.getCause.isInstanceOf[java.io.FileNotFoundException])
+
+ sc.stop()
+
+ // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=true.
+ val conf = new SparkConf().set(IGNORE_MISSING_FILES, true)
+ sc = new SparkContext("local", "test", conf)
+ assert(sc.textFile(deletedPath.toString).collect().isEmpty)
+
+ assert(collectRDDAndDeleteFileBeforeCompute(false).isEmpty)
+
+ assert(sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat],
+ classOf[LongWritable], classOf[Text]).collect().isEmpty)
+
+ assert(collectRDDAndDeleteFileBeforeCompute(true).isEmpty)
+ }
}
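
A hedged usage sketch of the setting exercised above: with spark.files.ignoreMissingFiles enabled, reading a path that has since been deleted yields an empty result instead of failing the job (the path below is a placeholder).

  import org.apache.spark.{SparkConf, SparkContext}

  val conf = new SparkConf()
    .setMaster("local")
    .setAppName("ignore-missing-files-example")
    .set("spark.files.ignoreMissingFiles", "true")
  val sc = new SparkContext(conf)
  val lines = sc.textFile("/tmp/possibly-deleted-path").collect()  // empty Array if the path is gone
  sc.stop()
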
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
index 8d7be77f51fe9..62824a5bec9d1 100644
--- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -135,7 +135,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
// This job runs 2 stages, and we're in the second stage. Therefore, any task attempt
// ID that's < 2 * numPartitions belongs to the first attempt of this stage.
val taskContext = TaskContext.get()
- val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2
+ val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2L
if (isFirstStageAttempt) {
throw new FetchFailedException(
SparkEnv.get.blockManager.blockManagerId,
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 8a77aea75a992..61da4138896cd 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.util.concurrent.Semaphore
+import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
@@ -26,7 +27,7 @@ import scala.concurrent.duration._
import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.util.ThreadUtils
/**
@@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
override def afterEach() {
try {
resetSparkContext()
+ JobCancellationSuite.taskStartedSemaphore.drainPermits()
+ JobCancellationSuite.taskCancelledSemaphore.drainPermits()
+ JobCancellationSuite.twoJobsSharingStageSemaphore.drainPermits()
+ JobCancellationSuite.executionOfInterruptibleCounter.set(0)
} finally {
super.afterEach()
}
@@ -320,6 +325,67 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
f2.get()
}
+ test("interruptible iterator of shuffle reader") {
+ // In this test case, we create a Spark job of two stages. The second stage is cancelled during
+ // execution and a counter is used to make sure that the corresponding tasks are indeed
+ // cancelled.
+ import JobCancellationSuite._
+ sc = new SparkContext("local[2]", "test interruptible iterator")
+
+ // Increase the number of elements to be processed to avoid this test being flaky.
+ val numElements = 10000
+ val taskCompletedSem = new Semaphore(0)
+
+ sc.addSparkListener(new SparkListener {
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+ // release taskCancelledSemaphore when cancelTasks event has been posted
+ if (stageCompleted.stageInfo.stageId == 1) {
+ taskCancelledSemaphore.release(numElements)
+ }
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+ if (taskEnd.stageId == 1) { // make sure tasks are completed
+ taskCompletedSem.release()
+ }
+ }
+ })
+
+ // Explicitly disable interrupting the task thread on task cancellation, so the task thread can
+ // only be interrupted by `InterruptibleIterator`.
+ sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
+
+ val f = sc.parallelize(1 to numElements).map { i => (i, i) }
+ .repartitionAndSortWithinPartitions(new HashPartitioner(1))
+ .mapPartitions { iter =>
+ taskStartedSemaphore.release()
+ iter
+ }.foreachAsync { x =>
+ // Block this code from being executed until the job gets cancelled. In this case, if the
+ // source iterator is interruptible, the number of increments should stay under
+ // `numElements`.
+ taskCancelledSemaphore.acquire()
+ executionOfInterruptibleCounter.getAndIncrement()
+ }
+
+ taskStartedSemaphore.acquire()
+ // The job is cancelled when:
+ // 1. a task in the reduce stage has started, guaranteed by the previous line.
+ // 2. the task in the reduce stage is blocked, as taskCancelledSemaphore is not released until
+ // the JobCancelled event is posted.
+ // After the job is cancelled, tasks in the reduce stage are cancelled asynchronously, so
+ // not all of the inputs should get processed (it's very unlikely that Spark can process
+ // 10000 elements between JobCancelled being posted and the task actually being killed).
+ f.cancel()
+
+ val e = intercept[SparkException](f.get()).getCause
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+
+ // Make sure tasks are indeed completed.
+ taskCompletedSem.acquire()
+ assert(executionOfInterruptibleCounter.get() < numElements)
+ }
+
def testCount() {
// Cancel before launching any tasks
{
@@ -381,7 +447,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
object JobCancellationSuite {
+ // To avoid any headaches, reset these global variables in the companion class's afterEach block
val taskStartedSemaphore = new Semaphore(0)
val taskCancelledSemaphore = new Semaphore(0)
val twoJobsSharingStageSemaphore = new Semaphore(0)
+ val executionOfInterruptibleCounter = new AtomicInteger(0)
}
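
A simplified sketch of the interruptible-iterator idea the new test relies on: the wrapper re-checks a kill flag on every hasNext, so a cancelled task stops consuming shuffle output early instead of draining the whole iterator (killed stands in for the task's interrupt flag; the real InterruptibleIterator also raises a kill exception).

  class CancellableIterator[T](killed: () => Boolean, underlying: Iterator[T]) extends Iterator[T] {
    override def hasNext: Boolean = !killed() && underlying.hasNext
    override def next(): T = underlying.next()
  }
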
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 50b8ea754d8d9..21f481d477242 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -147,7 +147,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("a", "hostA", 1000), Array(1000L)))
slaveTracker.updateEpoch(masterTracker.getEpoch)
- assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
assert(0 == masterTracker.getNumCachedSerializedBroadcast)
@@ -298,4 +298,33 @@ class MapOutputTrackerSuite extends SparkFunSuite {
}
}
+ test("zero-sized blocks should be excluded when getMapSizesByExecutorId") {
+ val rpcEnv = createRpcEnv("test")
+ val tracker = newTrackerMaster()
+ tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+ tracker.registerShuffle(10, 2)
+
+ val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L))
+ val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+ val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
+ tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
+ Array(size0, size1000, size0, size10000)))
+ tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
+ Array(size10000, size0, size1000, size0)))
+ assert(tracker.containsShuffle(10))
+ assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq ===
+ Seq(
+ (BlockManagerId("a", "hostA", 1000),
+ Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))),
+ (BlockManagerId("b", "hostB", 1000),
+ Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000)))
+ )
+ )
+
+ tracker.unregisterShuffle(10)
+ tracker.stop()
+ rpcEnv.shutdown()
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
deleted file mode 100644
index 33270bec6247c..0000000000000
--- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import java.io.File
-
-object SSLSampleConfigs {
- val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath
- val untrustedKeyStorePath = new File(
- this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath
- val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath
-
- val enabledAlgorithms =
- // A reasonable set of TLSv1.2 Oracle security provider suites
- "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " +
- "TLS_RSA_WITH_AES_256_CBC_SHA256, " +
- "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " +
- "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " +
- "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " +
- // and their equivalent names in the IBM Security provider
- "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " +
- "SSL_RSA_WITH_AES_256_CBC_SHA256, " +
- "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " +
- "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " +
- "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256"
-
- def sparkSSLConfig(): SparkConf = {
- val conf = new SparkConf(loadDefaults = false)
- conf.set("spark.ssl.enabled", "true")
- conf.set("spark.ssl.keyStore", keyStorePath)
- conf.set("spark.ssl.keyStorePassword", "password")
- conf.set("spark.ssl.keyPassword", "password")
- conf.set("spark.ssl.trustStore", trustStorePath)
- conf.set("spark.ssl.trustStorePassword", "password")
- conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms)
- conf.set("spark.ssl.protocol", "TLSv1.2")
- conf
- }
-
- def sparkSSLConfigUntrusted(): SparkConf = {
- val conf = new SparkConf(loadDefaults = false)
- conf.set("spark.ssl.enabled", "true")
- conf.set("spark.ssl.keyStore", untrustedKeyStorePath)
- conf.set("spark.ssl.keyStorePassword", "password")
- conf.set("spark.ssl.keyPassword", "password")
- conf.set("spark.ssl.trustStore", trustStorePath)
- conf.set("spark.ssl.trustStorePassword", "password")
- conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms)
- conf.set("spark.ssl.protocol", "TLSv1.2")
- conf
- }
-
-}
diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
index cf59265dd646d..e357299770a2e 100644
--- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
@@ -370,51 +370,6 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties {
assert(securityManager.checkModifyPermissions("user1") === false)
}
- test("ssl on setup") {
- val conf = SSLSampleConfigs.sparkSSLConfig()
- val expectedAlgorithms = Set(
- "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384",
- "TLS_RSA_WITH_AES_256_CBC_SHA256",
- "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256",
- "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
- "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256",
- "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384",
- "SSL_RSA_WITH_AES_256_CBC_SHA256",
- "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256",
- "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
- "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256")
-
- val securityManager = new SecurityManager(conf)
-
- assert(securityManager.fileServerSSLOptions.enabled === true)
-
- assert(securityManager.sslSocketFactory.isDefined === true)
- assert(securityManager.hostnameVerifier.isDefined === true)
-
- assert(securityManager.fileServerSSLOptions.trustStore.isDefined === true)
- assert(securityManager.fileServerSSLOptions.trustStore.get.getName === "truststore")
- assert(securityManager.fileServerSSLOptions.keyStore.isDefined === true)
- assert(securityManager.fileServerSSLOptions.keyStore.get.getName === "keystore")
- assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password"))
- assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password"))
- assert(securityManager.fileServerSSLOptions.keyPassword === Some("password"))
- assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2"))
- assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms)
- }
-
- test("ssl off setup") {
- val file = File.createTempFile("SSLOptionsSuite", "conf", Utils.createTempDir())
-
- System.setProperty("spark.ssl.configFile", file.getAbsolutePath)
- val conf = new SparkConf()
-
- val securityManager = new SecurityManager(conf)
-
- assert(securityManager.fileServerSSLOptions.enabled === false)
- assert(securityManager.sslSocketFactory.isDefined === false)
- assert(securityManager.hostnameVerifier.isDefined === false)
- }
-
test("missing secret authentication key") {
val conf = new SparkConf().set("spark.authenticate", "true")
val mgr = new SecurityManager(conf)
@@ -440,23 +395,41 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties {
assert(keyFromEnv === new SecurityManager(conf2).getSecretKey())
}
- test("secret key generation in yarn mode") {
- val conf = new SparkConf()
- .set(NETWORK_AUTH_ENABLED, true)
- .set(SparkLauncher.SPARK_MASTER, "yarn")
- val mgr = new SecurityManager(conf)
-
- UserGroupInformation.createUserForTesting("authTest", Array()).doAs(
- new PrivilegedExceptionAction[Unit]() {
- override def run(): Unit = {
- mgr.initializeAuth()
- val creds = UserGroupInformation.getCurrentUser().getCredentials()
- val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY)
- assert(secret != null)
- assert(new String(secret, UTF_8) === mgr.getSecretKey())
+ test("secret key generation") {
+ Seq(
+ ("yarn", true),
+ ("local", true),
+ ("local[*]", true),
+ ("local[1, 2]", true),
+ ("local-cluster[2, 1, 1024]", false),
+ ("invalid", false)
+ ).foreach { case (master, shouldGenerateSecret) =>
+ val conf = new SparkConf()
+ .set(NETWORK_AUTH_ENABLED, true)
+ .set(SparkLauncher.SPARK_MASTER, master)
+ val mgr = new SecurityManager(conf)
+
+ UserGroupInformation.createUserForTesting("authTest", Array()).doAs(
+ new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
+ if (shouldGenerateSecret) {
+ mgr.initializeAuth()
+ val creds = UserGroupInformation.getCurrentUser().getCredentials()
+ val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY)
+ assert(secret != null)
+ assert(new String(secret, UTF_8) === mgr.getSecretKey())
+ } else {
+ intercept[IllegalArgumentException] {
+ mgr.initializeAuth()
+ }
+ intercept[IllegalArgumentException] {
+ mgr.getSecretKey()
+ }
+ }
+ }
}
- }
- )
+ )
+ }
}
}
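
The rewritten test encodes a simple decision: auto-generated secrets are only supported for YARN and purely local masters, while other masters must configure the secret explicitly. A sketch of that check, matching the masters listed above (assumed logic, not the actual SecurityManager code):

  def canGenerateSecret(master: String): Boolean =
    master == "yarn" || master == "local" || master.startsWith("local[")
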
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index bff808eb540ac..0d06b02e74e34 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -339,6 +339,38 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
}
}
+ val defaultIllegalValue = "SomeIllegalValue"
+ val illegalValueTests : Map[String, (SparkConf, String) => Any] = Map(
+ "getTimeAsSeconds" -> (_.getTimeAsSeconds(_)),
+ "getTimeAsSeconds with default" -> (_.getTimeAsSeconds(_, defaultIllegalValue)),
+ "getTimeAsMs" -> (_.getTimeAsMs(_)),
+ "getTimeAsMs with default" -> (_.getTimeAsMs(_, defaultIllegalValue)),
+ "getSizeAsBytes" -> (_.getSizeAsBytes(_)),
+ "getSizeAsBytes with default string" -> (_.getSizeAsBytes(_, defaultIllegalValue)),
+ "getSizeAsBytes with default long" -> (_.getSizeAsBytes(_, 0L)),
+ "getSizeAsKb" -> (_.getSizeAsKb(_)),
+ "getSizeAsKb with default" -> (_.getSizeAsKb(_, defaultIllegalValue)),
+ "getSizeAsMb" -> (_.getSizeAsMb(_)),
+ "getSizeAsMb with default" -> (_.getSizeAsMb(_, defaultIllegalValue)),
+ "getSizeAsGb" -> (_.getSizeAsGb(_)),
+ "getSizeAsGb with default" -> (_.getSizeAsGb(_, defaultIllegalValue)),
+ "getInt" -> (_.getInt(_, 0)),
+ "getLong" -> (_.getLong(_, 0L)),
+ "getDouble" -> (_.getDouble(_, 0.0)),
+ "getBoolean" -> (_.getBoolean(_, false))
+ )
+
+ illegalValueTests.foreach { case (name, getValue) =>
+ test(s"SPARK-24337: $name throws an useful error message with key name") {
+ val key = "SomeKey"
+ val conf = new SparkConf()
+ conf.set(key, "SomeInvalidValue")
+ val thrown = intercept[IllegalArgumentException] {
+ getValue(conf, key)
+ }
+ assert(thrown.getMessage.contains(key))
+ }
+ }
}
class Class1 {}
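
A short usage sketch of the behavior these tests assert, using an illustrative key: a malformed value surfaces an IllegalArgumentException whose message names the offending key.

  import org.apache.spark.SparkConf

  val conf = new SparkConf().set("spark.driver.memory", "NotASize")
  try {
    conf.getSizeAsBytes("spark.driver.memory")
  } catch {
    case e: IllegalArgumentException =>
      assert(e.getMessage.contains("spark.driver.memory"))
  }
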
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index b30bd74812b36..ce9f2be1c02dd 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark
import java.io.File
import java.net.{MalformedURLException, URI}
import java.nio.charset.StandardCharsets
-import java.util.concurrent.{Semaphore, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
import scala.concurrent.duration._
@@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
test("Cancelling stages/jobs with custom reasons.") {
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+ sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")
val REASON = "You shall not pass"
- val slices = 10
- val listener = new SparkListener {
- override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
- if (SparkContextSuite.cancelStage) {
- eventually(timeout(10.seconds)) {
- assert(SparkContextSuite.isTaskStarted)
+ for (cancelWhat <- Seq("stage", "job")) {
+ // This countdown latch is used to make sure the stage or job is cancelled in the listener
+ val latch = new CountDownLatch(1)
+
+ val listener = cancelWhat match {
+ case "stage" =>
+ new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+ sc.cancelStage(taskStart.stageId, REASON)
+ latch.countDown()
+ }
}
- sc.cancelStage(taskStart.stageId, REASON)
- SparkContextSuite.cancelStage = false
- SparkContextSuite.semaphore.release(slices)
- }
- }
-
- override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
- if (SparkContextSuite.cancelJob) {
- eventually(timeout(10.seconds)) {
- assert(SparkContextSuite.isTaskStarted)
+ case "job" =>
+ new SparkListener {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ sc.cancelJob(jobStart.jobId, REASON)
+ latch.countDown()
+ }
}
- sc.cancelJob(jobStart.jobId, REASON)
- SparkContextSuite.cancelJob = false
- SparkContextSuite.semaphore.release(slices)
- }
}
- }
- sc.addSparkListener(listener)
-
- for (cancelWhat <- Seq("stage", "job")) {
- SparkContextSuite.semaphore.drainPermits()
- SparkContextSuite.isTaskStarted = false
- SparkContextSuite.cancelStage = (cancelWhat == "stage")
- SparkContextSuite.cancelJob = (cancelWhat == "job")
+ sc.addSparkListener(listener)
val ex = intercept[SparkException] {
- sc.range(0, 10000L, numSlices = slices).mapPartitions { x =>
- SparkContextSuite.isTaskStarted = true
- // Block waiting for the listener to cancel the stage or job.
- SparkContextSuite.semaphore.acquire()
+ sc.range(0, 10000L, numSlices = 10).mapPartitions { x =>
+ x.synchronized {
+ x.wait()
+ }
x
}.count()
}
@@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.")
}
+ latch.await(20, TimeUnit.SECONDS)
eventually(timeout(20.seconds)) {
assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0)
}
+ sc.removeSparkListener(listener)
}
}
@@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
}
object SparkContextSuite {
- @volatile var cancelJob = false
- @volatile var cancelStage = false
@volatile var isTaskStarted = false
@volatile var taskKilled = false
@volatile var taskSucceeded = false
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index 3af9d82393bc4..31289026b0027 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -59,6 +59,7 @@ abstract class SparkFunSuite
protected val enableAutoThreadAudit = true
protected override def beforeAll(): Unit = {
+ System.setProperty("spark.testing", "true")
if (enableAutoThreadAudit) {
doThreadPreAudit()
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
index 32dd3ecc2f027..ef947eb074647 100644
--- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
@@ -66,7 +66,6 @@ class RPackageUtilsSuite
override def beforeEach(): Unit = {
super.beforeEach()
- System.setProperty("spark.testing", "true")
lineBuffer.clear()
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 27dd435332348..545c8d0423dc3 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.deploy
import java.io._
import java.net.URI
import java.nio.charset.StandardCharsets
-import java.nio.file.Files
+import java.nio.file.{Files, Paths}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -35,12 +35,14 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.SpanSugar._
import org.apache.spark._
+import org.apache.spark.TestUtils
import org.apache.spark.TestUtils.JavaSourceFromString
import org.apache.spark.api.r.RUtils
import org.apache.spark.deploy.SparkSubmit._
import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
+import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.scheduler.EventLoggingListener
import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils}
@@ -105,9 +107,13 @@ class SparkSubmitSuite
// Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
implicit val defaultSignaler: Signaler = ThreadSignaler
+ private val emptyIvySettings = File.createTempFile("ivy", ".xml")
+ FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8)
+
+ private val submit = new SparkSubmit()
+
override def beforeEach() {
super.beforeEach()
- System.setProperty("spark.testing", "true")
}
// scalastyle:off println
@@ -125,13 +131,16 @@ class SparkSubmitSuite
}
test("handle binary specified but not class") {
- testPrematureExit(Array("foo.jar"), "No main class")
+ val jar = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
+ testPrematureExit(Array(jar.toString()), "No main class")
}
test("handles arguments with --key=val") {
val clArgs = Seq(
"--jars=one.jar,two.jar,three.jar",
- "--name=myApp")
+ "--name=myApp",
+ "--class=org.FooBar",
+ SparkLauncher.NO_RESOURCE)
val appArgs = new SparkSubmitArguments(clArgs)
appArgs.jars should include regex (".*one.jar,.*two.jar,.*three.jar")
appArgs.name should be ("myApp")
@@ -171,6 +180,26 @@ class SparkSubmitSuite
appArgs.toString should include ("thequeue")
}
+ test("SPARK-24241: do not fail fast if executor num is 0 when dynamic allocation is enabled") {
+ val clArgs1 = Seq(
+ "--name", "myApp",
+ "--class", "Foo",
+ "--num-executors", "0",
+ "--conf", "spark.dynamicAllocation.enabled=true",
+ "thejar.jar")
+ new SparkSubmitArguments(clArgs1)
+
+ val clArgs2 = Seq(
+ "--name", "myApp",
+ "--class", "Foo",
+ "--num-executors", "0",
+ "--conf", "spark.dynamicAllocation.enabled=false",
+ "thejar.jar")
+
+ val e = intercept[SparkException](new SparkSubmitArguments(clArgs2))
+ assert(e.getMessage.contains("Number of executors must be a positive number"))
+ }
+
test("specify deploy mode through configuration") {
val clArgs = Seq(
"--master", "yarn",
@@ -179,7 +208,7 @@ class SparkSubmitSuite
"thejar.jar"
)
val appArgs = new SparkSubmitArguments(clArgs)
- val (_, _, conf, _) = prepareSubmitEnvironment(appArgs)
+ val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs)
appArgs.deployMode should be ("client")
conf.get("spark.submit.deployMode") should be ("client")
@@ -189,11 +218,11 @@ class SparkSubmitSuite
"--master", "yarn",
"--deploy-mode", "cluster",
"--conf", "spark.submit.deployMode=client",
- "-class", "org.SomeClass",
+ "--class", "org.SomeClass",
"thejar.jar"
)
val appArgs1 = new SparkSubmitArguments(clArgs1)
- val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1)
+ val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1)
appArgs1.deployMode should be ("cluster")
conf1.get("spark.submit.deployMode") should be ("cluster")
@@ -207,7 +236,7 @@ class SparkSubmitSuite
val appArgs2 = new SparkSubmitArguments(clArgs2)
appArgs2.deployMode should be (null)
- val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2)
+ val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2)
appArgs2.deployMode should be ("client")
conf2.get("spark.submit.deployMode") should be ("client")
}
@@ -230,7 +259,7 @@ class SparkSubmitSuite
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs)
+ val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs)
val childArgsStr = childArgs.mkString(" ")
childArgsStr should include ("--class org.SomeClass")
childArgsStr should include ("--arg arg1 --arg arg2")
@@ -273,7 +302,7 @@ class SparkSubmitSuite
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs)
+ val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs)
childArgs.mkString(" ") should be ("arg1 arg2")
mainClass should be ("org.SomeClass")
classpath should have length (4)
@@ -319,7 +348,7 @@ class SparkSubmitSuite
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
appArgs.useRest = useRest
- val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs)
+ val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs)
val childArgsStr = childArgs.mkString(" ")
if (useRest) {
childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2")
@@ -356,7 +385,7 @@ class SparkSubmitSuite
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs)
+ val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs)
childArgs.mkString(" ") should be ("arg1 arg2")
mainClass should be ("org.SomeClass")
classpath should have length (1)
@@ -378,7 +407,7 @@ class SparkSubmitSuite
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs)
+ val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs)
childArgs.mkString(" ") should be ("arg1 arg2")
mainClass should be ("org.SomeClass")
classpath should have length (1)
@@ -400,7 +429,7 @@ class SparkSubmitSuite
"/home/thejar.jar",
"arg1")
val appArgs = new SparkSubmitArguments(clArgs)
- val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs)
+ val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs)
val childArgsMap = childArgs.grouped(2).map(a => a(0) -> a(1)).toMap
childArgsMap.get("--primary-java-resource") should be (Some("file:/home/thejar.jar"))
@@ -425,7 +454,7 @@ class SparkSubmitSuite
"thejar.jar",
"arg1", "arg2")
val appArgs = new SparkSubmitArguments(clArgs)
- val (_, _, conf, mainClass) = prepareSubmitEnvironment(appArgs)
+ val (_, _, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs)
conf.get("spark.executor.memory") should be ("5g")
conf.get("spark.master") should be ("yarn")
conf.get("spark.submit.deployMode") should be ("cluster")
@@ -438,12 +467,12 @@ class SparkSubmitSuite
val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell")
val appArgs1 = new SparkSubmitArguments(clArgs1)
- val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1)
+ val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1)
conf1.get(UI_SHOW_CONSOLE_PROGRESS) should be (true)
val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar")
val appArgs2 = new SparkSubmitArguments(clArgs2)
- val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2)
+ val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2)
assert(!conf2.contains(UI_SHOW_CONSOLE_PROGRESS))
}
@@ -520,6 +549,7 @@ class SparkSubmitSuite
"--repositories", repo,
"--conf", "spark.ui.enabled=false",
"--conf", "spark.master.rest.enabled=false",
+ "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}",
unusedJar.toString,
"my.great.lib.MyLib", "my.great.dep.MyLib")
runSparkSubmit(args)
@@ -530,7 +560,6 @@ class SparkSubmitSuite
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val main = MavenCoordinate("my.great.lib", "mylib", "0.1")
val dep = MavenCoordinate("my.great.dep", "mylib", "0.1")
- // Test using "spark.jars.packages" and "spark.jars.repositories" configurations.
IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo =>
val args = Seq(
"--class", JarCreationTest.getClass.getName.stripSuffix("$"),
@@ -540,6 +569,7 @@ class SparkSubmitSuite
"--conf", s"spark.jars.repositories=$repo",
"--conf", "spark.ui.enabled=false",
"--conf", "spark.master.rest.enabled=false",
+ "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}",
unusedJar.toString,
"my.great.lib.MyLib", "my.great.dep.MyLib")
runSparkSubmit(args)
@@ -550,7 +580,6 @@ class SparkSubmitSuite
// See https://gist.github.com/shivaram/3a2fecce60768a603dac for an error log
ignore("correctly builds R packages included in a jar with --packages") {
assume(RUtils.isRInstalled, "R isn't installed on this machine.")
- // Check if the SparkR package is installed
assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.")
val main = MavenCoordinate("my.great.lib", "mylib", "0.1")
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
@@ -563,6 +592,7 @@ class SparkSubmitSuite
"--master", "local-cluster[2,1,1024]",
"--packages", main.toString,
"--repositories", repo,
+ "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}",
"--verbose",
"--conf", "spark.ui.enabled=false",
rScriptDir)
@@ -573,7 +603,6 @@ class SparkSubmitSuite
test("include an external JAR in SparkR") {
assume(RUtils.isRInstalled, "R isn't installed on this machine.")
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
- // Check if the SparkR package is installed
assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.")
val rScriptDir =
Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator)
@@ -606,10 +635,13 @@ class SparkSubmitSuite
}
test("resolves command line argument paths correctly") {
- val jars = "/jar1,/jar2" // --jars
- val files = "local:/file1,file2" // --files
- val archives = "file:/archive1,archive2" // --archives
- val pyFiles = "py-file1,py-file2" // --py-files
+ val dir = Utils.createTempDir()
+ val archive = Paths.get(dir.toPath.toString, "single.zip")
+ Files.createFile(archive)
+ val jars = "/jar1,/jar2"
+ val files = "local:/file1,file2"
+ val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3"
+ val pyFiles = "py-file1,py-file2"
// Test jars and files
val clArgs = Seq(
@@ -619,7 +651,7 @@ class SparkSubmitSuite
"--files", files,
"thejar.jar")
val appArgs = new SparkSubmitArguments(clArgs)
- val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs)
+ val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs)
appArgs.jars should be (Utils.resolveURIs(jars))
appArgs.files should be (Utils.resolveURIs(files))
conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar"))
@@ -634,11 +666,12 @@ class SparkSubmitSuite
"thejar.jar"
)
val appArgs2 = new SparkSubmitArguments(clArgs2)
- val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2)
+ val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2)
appArgs2.files should be (Utils.resolveURIs(files))
- appArgs2.archives should be (Utils.resolveURIs(archives))
+ appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3")
conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files))
- conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives))
+ conf2.get("spark.yarn.dist.archives") should fullyMatch regex
+ ("file:/archive1,file:.*#archive3")
// Test python files
val clArgs3 = Seq(
@@ -649,7 +682,7 @@ class SparkSubmitSuite
"mister.py"
)
val appArgs3 = new SparkSubmitArguments(clArgs3)
- val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3)
+ val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3)
appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles))
conf3.get("spark.submit.pyFiles") should be (
PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(","))
@@ -657,6 +690,29 @@ class SparkSubmitSuite
conf3.get(PYSPARK_PYTHON.key) should be ("python3.5")
}
+ test("ambiguous archive mapping results in error message") {
+ val dir = Utils.createTempDir()
+ val archive1 = Paths.get(dir.toPath.toString, "first.zip")
+ val archive2 = Paths.get(dir.toPath.toString, "second.zip")
+ Files.createFile(archive1)
+ Files.createFile(archive2)
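+ // the *.zip glob used below matches both files, so the #archive3 alias cannot resolve uniquely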
+ val jars = "/jar1,/jar2"
+ val files = "local:/file1,file2"
+ val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3"
+ val pyFiles = "py-file1,py-file2"
+
+ // Test files and archives (Yarn)
+ val clArgs2 = Seq(
+ "--master", "yarn",
+ "--class", "org.SomeClass",
+ "--files", files,
+ "--archives", archives,
+ "thejar.jar"
+ )
+
+ testPrematureExit(clArgs2.toArray, "resolves ambiguously to multiple files")
+ }
+
test("resolves config paths correctly") {
val jars = "/jar1,/jar2" // spark.jars
val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files
@@ -678,7 +734,7 @@ class SparkSubmitSuite
"thejar.jar"
)
val appArgs = new SparkSubmitArguments(clArgs)
- val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs)
+ val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs)
conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar"))
conf.get("spark.files") should be(Utils.resolveURIs(files))
@@ -695,7 +751,7 @@ class SparkSubmitSuite
"thejar.jar"
)
val appArgs2 = new SparkSubmitArguments(clArgs2)
- val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2)
+ val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2)
conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files))
conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives))
@@ -710,14 +766,18 @@ class SparkSubmitSuite
"mister.py"
)
val appArgs3 = new SparkSubmitArguments(clArgs3)
- val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3)
+ val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3)
conf3.get("spark.submit.pyFiles") should be(
PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(","))
// Test remote python files
+ val hadoopConf = new Configuration()
+ updateConfWithFakeS3Fs(hadoopConf)
val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir)
+ val pyFile1 = File.createTempFile("file1", ".py", tmpDir)
+ val pyFile2 = File.createTempFile("file2", ".py", tmpDir)
val writer4 = new PrintWriter(f4)
- val remotePyFiles = "hdfs:///tmp/file1.py,hdfs:///tmp/file2.py"
+ val remotePyFiles = s"s3a://${pyFile1.getAbsolutePath},s3a://${pyFile2.getAbsolutePath}"
writer4.println("spark.submit.pyFiles " + remotePyFiles)
writer4.close()
val clArgs4 = Seq(
@@ -727,7 +787,7 @@ class SparkSubmitSuite
"hdfs:///tmp/mister.py"
)
val appArgs4 = new SparkSubmitArguments(clArgs4)
- val (_, _, conf4, _) = SparkSubmit.prepareSubmitEnvironment(appArgs4)
+ val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4, conf = Some(hadoopConf))
// Should not format python path for yarn cluster mode
conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles))
}
@@ -748,32 +808,20 @@ class SparkSubmitSuite
}
test("SPARK_CONF_DIR overrides spark-defaults.conf") {
- forConfDir(Map("spark.executor.memory" -> "2.3g")) { path =>
+ forConfDir(Map("spark.executor.memory" -> "3g")) { path =>
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val args = Seq(
"--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"),
"--name", "testApp",
"--master", "local",
unusedJar.toString)
- val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path))
+ val appArgs = new SparkSubmitArguments(args, env = Map("SPARK_CONF_DIR" -> path))
assert(appArgs.propertiesFile != null)
assert(appArgs.propertiesFile.startsWith(path))
- appArgs.executorMemory should be ("2.3g")
+ appArgs.executorMemory should be ("3g")
}
}
- test("comma separated list of files are unioned correctly") {
- val left = Option("/tmp/a.jar,/tmp/b.jar")
- val right = Option("/tmp/c.jar,/tmp/a.jar")
- val emptyString = Option("")
- Utils.unionFileLists(left, right) should be (Set("/tmp/a.jar", "/tmp/b.jar", "/tmp/c.jar"))
- Utils.unionFileLists(emptyString, emptyString) should be (Set.empty)
- Utils.unionFileLists(Option("/tmp/a.jar"), emptyString) should be (Set("/tmp/a.jar"))
- Utils.unionFileLists(emptyString, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar"))
- Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar"))
- Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar"))
- }
-
test("support glob path") {
val tmpJarDir = Utils.createTempDir()
val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir)
@@ -791,6 +839,9 @@ class SparkSubmitSuite
val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir)
val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir)
+ val tempPyFile = File.createTempFile("tmpApp", ".py")
+ tempPyFile.deleteOnExit()
+
val args = Seq(
"--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"),
"--name", "testApp",
@@ -800,10 +851,10 @@ class SparkSubmitSuite
"--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*",
"--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*",
"--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip",
- jar2.toString)
+ tempPyFile.toURI().toString())
val appArgs = new SparkSubmitArguments(args)
- val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs)
+ val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs)
conf.get("spark.yarn.dist.jars").split(",").toSet should be
(Set(jar1.toURI.toString, jar2.toURI.toString))
conf.get("spark.yarn.dist.files").split(",").toSet should be
@@ -929,7 +980,7 @@ class SparkSubmitSuite
)
val appArgs = new SparkSubmitArguments(args)
- val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))
+ val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf))
// All the resources should still be remote paths, so that YARN client will not upload again.
conf.get("spark.yarn.dist.jars") should be (tmpJarPath)
@@ -944,25 +995,28 @@ class SparkSubmitSuite
}
test("download remote resource if it is not supported by yarn service") {
- testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = false)
+ testRemoteResources(enableHttpFs = false, blacklistHttpFs = false)
}
test("avoid downloading remote resource if it is supported by yarn service") {
- testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = true)
+ testRemoteResources(enableHttpFs = true, blacklistHttpFs = false)
}
test("force download from blacklisted schemes") {
- testRemoteResources(isHttpSchemeBlacklisted = true, supportMockHttpFs = true)
+ testRemoteResources(enableHttpFs = true, blacklistHttpFs = true)
}
- private def testRemoteResources(isHttpSchemeBlacklisted: Boolean,
- supportMockHttpFs: Boolean): Unit = {
+ private def testRemoteResources(
+ enableHttpFs: Boolean,
+ blacklistHttpFs: Boolean): Unit = {
val hadoopConf = new Configuration()
updateConfWithFakeS3Fs(hadoopConf)
- if (supportMockHttpFs) {
+ if (enableHttpFs) {
hadoopConf.set("fs.http.impl", classOf[TestFileSystem].getCanonicalName)
- hadoopConf.set("fs.http.impl.disable.cache", "true")
+ } else {
+ hadoopConf.set("fs.http.impl", getClass().getName() + ".DoesNotExist")
}
+ hadoopConf.set("fs.http.impl.disable.cache", "true")
val tmpDir = Utils.createTempDir()
val mainResource = File.createTempFile("tmpPy", ".py", tmpDir)
@@ -971,30 +1025,29 @@ class SparkSubmitSuite
val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir)
val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}"
+ val forceDownloadArgs = if (blacklistHttpFs) {
+ Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http")
+ } else {
+ Nil
+ }
+
val args = Seq(
"--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"),
"--name", "testApp",
"--master", "yarn",
"--deploy-mode", "client",
- "--jars", s"$tmpS3JarPath,$tmpHttpJarPath",
- s"s3a://$mainResource"
- ) ++ (
- if (isHttpSchemeBlacklisted) {
- Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http,https")
- } else {
- Nil
- }
- )
+ "--jars", s"$tmpS3JarPath,$tmpHttpJarPath"
+ ) ++ forceDownloadArgs ++ Seq(s"s3a://$mainResource")
val appArgs = new SparkSubmitArguments(args)
- val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))
+ val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf))
val jars = conf.get("spark.yarn.dist.jars").split(",").toSet
// The URI of remote S3 resource should still be remote.
assert(jars.contains(tmpS3JarPath))
- if (supportMockHttpFs) {
+ if (enableHttpFs && !blacklistHttpFs) {
// If Http FS is supported by yarn service, the URI of remote http resource should
// still be remote.
assert(jars.contains(tmpHttpJarPath))
@@ -1038,11 +1091,50 @@ class SparkSubmitSuite
"hello")
val exception = intercept[SparkException] {
- SparkSubmit.main(args)
+ submit.doSubmit(args)
}
assert(exception.getMessage() === "hello")
}
+
+ test("support --py-files/spark.submit.pyFiles in non pyspark application") {
+ val hadoopConf = new Configuration()
+ updateConfWithFakeS3Fs(hadoopConf)
+
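+ // a remote .egg passed via --py-files should still be localized for a non-PySpark application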
+ val tmpDir = Utils.createTempDir()
+ val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir)
+
+ val args = Seq(
+ "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"),
+ "--name", "testApp",
+ "--master", "yarn",
+ "--deploy-mode", "client",
+ "--py-files", s"s3a://${pyFile.getAbsolutePath}",
+ "spark-internal"
+ )
+
+ val appArgs = new SparkSubmitArguments(args)
+ val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf))
+
+ conf.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}")
+ conf.get("spark.submit.pyFiles") should (startWith("/"))
+
+ // Verify "spark.submit.pyFiles"
+ val args1 = Seq(
+ "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"),
+ "--name", "testApp",
+ "--master", "yarn",
+ "--deploy-mode", "client",
+ "--conf", s"spark.submit.pyFiles=s3a://${pyFile.getAbsolutePath}",
+ "spark-internal"
+ )
+
+ val appArgs1 = new SparkSubmitArguments(args1)
+ val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1, conf = Some(hadoopConf))
+
+ conf1.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}")
+ conf1.get("spark.submit.pyFiles") should (startWith("/"))
+ }
}
object SparkSubmitSuite extends SparkFunSuite with TimeLimits {
@@ -1077,7 +1169,7 @@ object SparkSubmitSuite extends SparkFunSuite with TimeLimits {
object JarCreationTest extends Logging {
def main(args: Array[String]) {
- Utils.configTestLog4j("INFO")
+ TestUtils.configTestLog4j("INFO")
val conf = new SparkConf()
val sc = new SparkContext(conf)
val result = sc.makeRDD(1 to 100, 10).mapPartitions { x =>
@@ -1101,7 +1193,7 @@ object JarCreationTest extends Logging {
object SimpleApplicationTest {
def main(args: Array[String]) {
- Utils.configTestLog4j("INFO")
+ TestUtils.configTestLog4j("INFO")
val conf = new SparkConf()
val sc = new SparkContext(conf)
val configs = Seq("spark.master", "spark.app.name")
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
index eb8c203ae7751..a0f09891787e0 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
@@ -256,4 +256,19 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(jarPath.indexOf("mydep") >= 0, "should find dependency")
}
}
+
+ test("SPARK-10878: test resolution files cleaned after resolving artifact") {
+ val main = new MavenCoordinate("my.great.lib", "mylib", "0.1")
+
+ IvyTestUtils.withRepository(main, None, None) { repo =>
+ val ivySettings = SparkSubmitUtils.buildIvySettings(Some(repo), Some(tempIvyPath))
+ val jarPath = SparkSubmitUtils.resolveMavenCoordinates(
+ main.toString,
+ ivySettings,
+ isTest = true)
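+ // files named after the synthetic spark-submit parent module should be cleaned from the cache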
+ val r = """.*org.apache.spark-spark-submit-parent-.*""".r
+ assert(!ivySettings.getDefaultCache.listFiles.map(_.getName)
+ .exists(r.findFirstIn(_).isDefined), "resolution files should be cleaned")
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
index bf7480d79f8a1..27cc47496c805 100644
--- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
@@ -573,7 +573,8 @@ class StandaloneDynamicAllocationSuite
syncExecutors(sc)
sc.schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
- b.killExecutors(Seq(executorId), replace = false, force)
+ b.killExecutors(Seq(executorId), adjustTargetNumExecutors = true, countFailures = false,
+ force)
case _ => fail("expected coarse grained scheduler")
}
}
@@ -610,7 +611,7 @@ class StandaloneDynamicAllocationSuite
* we submit a request to kill them. This must be called before each kill request.
*/
private def syncExecutors(sc: SparkContext): Unit = {
- val driverExecutors = sc.getExecutorStorageStatus
+ val driverExecutors = sc.env.blockManager.master.getStorageStatus
.map(_.blockManagerId.executorId)
.filter { _ != SparkContext.DRIVER_IDENTIFIER}
val masterExecutors = getExecutorIds(sc)
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index fde5f25bce456..77b239489d489 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.hdfs.DistributedFileSystem
import org.json4s.jackson.JsonMethods._
import org.mockito.Matchers.any
-import org.mockito.Mockito.{doReturn, mock, spy, verify}
+import org.mockito.Mockito.{mock, spy, verify}
import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers
import org.scalatest.concurrent.Eventually._
@@ -151,8 +151,9 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
var mergeApplicationListingCall = 0
override protected def mergeApplicationListing(
fileStatus: FileStatus,
- lastSeen: Long): Unit = {
- super.mergeApplicationListing(fileStatus, lastSeen)
+ lastSeen: Long,
+ enableSkipToEnd: Boolean): Unit = {
+ super.mergeApplicationListing(fileStatus, lastSeen, enableSkipToEnd)
mergeApplicationListingCall += 1
}
}
@@ -256,14 +257,13 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
)
updateAndCheck(provider) { list =>
- list should not be (null)
list.size should be (1)
list.head.attempts.size should be (3)
list.head.attempts.head.attemptId should be (Some("attempt3"))
}
val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false)
- writeFile(attempt1, true, None,
+ writeFile(app2Attempt1, true, None,
SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")),
SparkListenerApplicationEnd(6L)
)
@@ -382,8 +382,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
val log = newLogFile("downloadApp1", Some(s"attempt$i"), inProgress = false)
writeFile(log, true, None,
SparkListenerApplicationStart(
- "downloadApp1", Some("downloadApp1"), 5000 * i, "test", Some(s"attempt$i")),
- SparkListenerApplicationEnd(5001 * i)
+ "downloadApp1", Some("downloadApp1"), 5000L * i, "test", Some(s"attempt$i")),
+ SparkListenerApplicationEnd(5001L * i)
)
log
}
@@ -649,8 +649,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
// Add more info to the app log, and trigger the provider to update things.
writeFile(appLog, true, None,
SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None),
- SparkListenerJobStart(0, 1L, Nil, null),
- SparkListenerApplicationEnd(5L)
+ SparkListenerJobStart(0, 1L, Nil, null)
)
provider.checkForLogs()
@@ -668,11 +667,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
test("clean up stale app information") {
val storeDir = Utils.createTempDir()
val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath())
- val provider = spy(new FsHistoryProvider(conf))
+ val clock = new ManualClock()
+ val provider = spy(new FsHistoryProvider(conf, clock))
val appId = "new1"
// Write logs for two app attempts.
- doReturn(1L).when(provider).getNewLastScanTime()
+ clock.advance(1)
val attempt1 = newLogFile(appId, Some("1"), inProgress = false)
writeFile(attempt1, true, None,
SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")),
@@ -697,7 +697,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
// Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since
// attempt 2 still exists, listing data should be there.
- doReturn(2L).when(provider).getNewLastScanTime()
+ clock.advance(1)
attempt1.delete()
updateAndCheck(provider) { list =>
assert(list.size === 1)
@@ -708,7 +708,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
assert(provider.getAppUI(appId, None) === None)
// Delete the second attempt's log file. Now everything should go away.
- doReturn(3L).when(provider).getNewLastScanTime()
+ clock.advance(1)
attempt2.delete()
updateAndCheck(provider) { list =>
assert(list.isEmpty)
@@ -718,9 +718,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
test("SPARK-21571: clean up removes invalid history files") {
val clock = new ManualClock()
val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d")
- val provider = new FsHistoryProvider(conf, clock) {
- override def getNewLastScanTime(): Long = clock.getTimeMillis()
- }
+ val provider = new FsHistoryProvider(conf, clock)
// Create 0-byte size inprogress and complete files
var logCount = 0
@@ -772,6 +770,54 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
assert(new File(testDir.toURI).listFiles().size === validLogCount)
}
+ test("always find end event for finished apps") {
+ // Create a log file where the end event is before the configured chunk to be reparsed at
+ // the end of the file. The correct listing should still be generated.
+ val log = newLogFile("end-event-test", None, inProgress = false)
+ writeFile(log, true, None,
+ Seq(
+ SparkListenerApplicationStart("end-event-test", Some("end-event-test"), 1L, "test", None),
+ SparkListenerEnvironmentUpdate(Map(
+ "Spark Properties" -> Seq.empty,
+ "JVM Information" -> Seq.empty,
+ "System Properties" -> Seq.empty,
+ "Classpath Entries" -> Seq.empty
+ )),
+ SparkListenerApplicationEnd(5L)
+ ) ++ (1 to 1000).map { i => SparkListenerJobStart(i, i, Nil) }: _*)
+
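+ // the 1000 job-start events appended above push the end event outside the 1k tail being reparsed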
+ val conf = createTestConf().set(END_EVENT_REPARSE_CHUNK_SIZE.key, s"1k")
+ val provider = new FsHistoryProvider(conf)
+ updateAndCheck(provider) { list =>
+ assert(list.size === 1)
+ assert(list(0).attempts.size === 1)
+ assert(list(0).attempts(0).completed)
+ }
+ }
+
+ test("parse event logs with optimizations off") {
+ val conf = createTestConf()
+ .set(END_EVENT_REPARSE_CHUNK_SIZE, 0L)
+ .set(FAST_IN_PROGRESS_PARSING, false)
+ val provider = new FsHistoryProvider(conf)
+
+ val complete = newLogFile("complete", None, inProgress = false)
+ writeFile(complete, true, None,
+ SparkListenerApplicationStart("complete", Some("complete"), 1L, "test", None),
+ SparkListenerApplicationEnd(5L)
+ )
+
+ val incomplete = newLogFile("incomplete", None, inProgress = true)
+ writeFile(incomplete, true, None,
+ SparkListenerApplicationStart("incomplete", Some("incomplete"), 1L, "test", None)
+ )
+
+ updateAndCheck(provider) { list =>
+ list.size should be (2)
+ list.count(_.attempts.head.completed) should be (1)
+ }
+ }
+
/**
* Asks the provider to check for logs and calls a function to perform checks on the updated
* app list. Example:
@@ -815,7 +861,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
private def createTestConf(inMemory: Boolean = false): SparkConf = {
val conf = new SparkConf()
- .set("spark.history.fs.logDirectory", testDir.getAbsolutePath())
+ .set(EVENT_LOG_DIR, testDir.getAbsolutePath())
+ .set(FAST_IN_PROGRESS_PARSING, true)
if (!inMemory) {
conf.set(LOCAL_STORE_DIR, Utils.createTempDir().getAbsolutePath())
@@ -848,4 +895,3 @@ class TestGroupsMappingProvider extends GroupMappingServiceProvider {
mappings.get(username).map(Set(_)).getOrElse(Set.empty)
}
}
-
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index 87f12f303cd5e..11b29121739a4 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -36,6 +36,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
import org.json4s.JsonAST._
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods._
+import org.mockito.Mockito._
import org.openqa.selenium.WebDriver
import org.openqa.selenium.htmlunit.HtmlUnitDriver
import org.scalatest.{BeforeAndAfter, Matchers}
@@ -281,6 +282,29 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND)
}
+ test("automatically retrieve uiRoot from request through Knox") {
+ assert(sys.props.get("spark.ui.proxyBase").isEmpty,
+ "spark.ui.proxyBase is defined but it should not be for this UT")
+ assert(sys.env.get("APPLICATION_WEB_PROXY_BASE").isEmpty,
+ "APPLICATION_WEB_PROXY_BASE is defined but it should not be for this UT")
+ val page = new HistoryPage(server)
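+ // simulate a request proxied through Knox by supplying the X-Forwarded-Context header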
+ val requestThroughKnox = mock[HttpServletRequest]
+ val knoxBaseUrl = "/gateway/default/sparkhistoryui"
+ when(requestThroughKnox.getHeader("X-Forwarded-Context")).thenReturn(knoxBaseUrl)
+ val responseThroughKnox = page.render(requestThroughKnox)
+
+ val urlsThroughKnox = responseThroughKnox \\ "@href" map (_.toString)
+ val siteRelativeLinksThroughKnox = urlsThroughKnox filter (_.startsWith("/"))
+ all (siteRelativeLinksThroughKnox) should startWith (knoxBaseUrl)
+
+ val directRequest = mock[HttpServletRequest]
+ val directResponse = page.render(directRequest)
+
+ val directUrls = directResponse \\ "@href" map (_.toString)
+ val directSiteRelativeLinks = directUrls filter (_.startsWith("/"))
+ all (directSiteRelativeLinks) should not startWith (knoxBaseUrl)
+ }
+
test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") {
val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase")
val page = new HistoryPage(server)
@@ -296,6 +320,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
all (siteRelativeLinks) should startWith (uiRoot)
}
+ test("/version api endpoint") {
+ val response = getUrl("version")
+ assert(response.contains(SPARK_VERSION))
+ }
+
test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") {
val uiRoot = "/testwebproxybase"
System.setProperty("spark.ui.proxyBase", uiRoot)
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
index e505bc018857d..54c168a8218f3 100644
--- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -445,7 +445,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach {
"--class", mainClass,
mainJar) ++ appArgs
val args = new SparkSubmitArguments(commandLineArgs)
- val (_, _, sparkConf, _) = SparkSubmit.prepareSubmitEnvironment(args)
+ val (_, _, sparkConf, _) = new SparkSubmit().prepareSubmitEnvironment(args)
new RestSubmissionClient("spark://host:port").constructSubmitRequest(
mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty)
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala
index ce212a7513310..e3fe2b696aa1f 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala
@@ -17,10 +17,19 @@
package org.apache.spark.deploy.worker
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.function.Supplier
+
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
import org.scalatest.{BeforeAndAfter, Matchers}
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.{Command, ExecutorState}
+import org.apache.spark.deploy.{Command, ExecutorState, ExternalShuffleService}
import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged}
import org.apache.spark.deploy.master.DriverState
import org.apache.spark.rpc.{RpcAddress, RpcEnv}
@@ -29,6 +38,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
import org.apache.spark.deploy.DeployTestUtils._
+ @Mock(answer = RETURNS_SMART_NULLS) private var shuffleService: ExternalShuffleService = _
+
def cmd(javaOpts: String*): Command = {
Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*))
}
@@ -36,15 +47,21 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
private var _worker: Worker = _
- private def makeWorker(conf: SparkConf): Worker = {
+ private def makeWorker(
+ conf: SparkConf,
+ shuffleServiceSupplier: Supplier[ExternalShuffleService] = null): Worker = {
assert(_worker === null, "Some Worker's RpcEnv is leaked in tests")
val securityMgr = new SecurityManager(conf)
val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, securityMgr)
_worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)),
- "Worker", "/tmp", conf, securityMgr)
+ "Worker", "/tmp", conf, securityMgr, shuffleServiceSupplier)
_worker
}
+ before {
+ MockitoAnnotations.initMocks(this)
+ }
+
after {
if (_worker != null) {
_worker.rpcEnv.shutdown()
@@ -194,4 +211,36 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
assert(worker.finishedDrivers.size === expectedValue)
}
}
+
+ test("cleanup non-shuffle files after executor exits when config " +
+ "spark.storage.cleanupFilesAfterExecutorExit=true") {
+ testCleanupFilesWithConfig(true)
+ }
+
+ test("don't cleanup non-shuffle files after executor exits when config " +
+ "spark.storage.cleanupFilesAfterExecutorExit=false") {
+ testCleanupFilesWithConfig(false)
+ }
+
+ private def testCleanupFilesWithConfig(value: Boolean) = {
+ val conf = new SparkConf().set("spark.storage.cleanupFilesAfterExecutorExit", value.toString)
+
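+ // record whether the mocked external shuffle service is asked to remove the executor's files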
+ val cleanupCalled = new AtomicBoolean(false)
+ when(shuffleService.executorRemoved(any[String], any[String])).thenAnswer(new Answer[Unit] {
+ override def answer(invocations: InvocationOnMock): Unit = {
+ cleanupCalled.set(true)
+ }
+ })
+ val externalShuffleServiceSupplier = new Supplier[ExternalShuffleService] {
+ override def get: ExternalShuffleService = shuffleService
+ }
+ val worker = makeWorker(conf, externalShuffleServiceSupplier)
+ // register ten executor runners with the worker
+ for (i <- 0 until 10) {
+ worker.executors += s"app1/$i" -> createExecutorRunner(i)
+ }
+ worker.handleExecutorStateChanged(
+ ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None))
+ assert(cleanupCalled.get() == value)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 105a178f2d94e..1a7bebe2c53cd 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable.Map
import scala.concurrent.duration._
@@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
// the fetch failure. The executor should still tell the driver that the task failed due to a
// fetch failure, not a generic exception from user code.
val inputRDD = new FetchFailureThrowingRDD(sc)
- val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+ val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
@@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}
test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+ val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
+ assert(failReason.isInstanceOf[ExceptionFailure])
+ val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+ verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+ assert(exceptionCaptor.getAllValues.size === 1)
+ assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
+ }
+
+ test("SPARK-23816: interrupts are not masked by a FetchFailure") {
+ // If killing the task causes a fetch failure, we still treat it as a task that was killed,
+ // as the fetch failure could easily be caused by interrupting the thread.
+ val (failReason, _) = testFetchFailureHandling(false)
+ assert(failReason.isInstanceOf[TaskKilled])
+ }
+
+ /**
+ * Helper for testing some cases where a FetchFailure should *not* get sent back, because it's
+ * superseded by another error, either an OOM or intentionally killing a task.
+ * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the
+ * FetchFailure
+ */
+ private def testFetchFailureHandling(
+ oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
// when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
// may be a false positive. And we should call the uncaught exception handler.
+ // SPARK-23816 also handles interrupts the same way, as killing an obsolete speculative task
+ // does not represent a real fetch failure.
val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
sc = new SparkContext(conf)
val serializer = SparkEnv.get.closureSerializer.newInstance()
val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
- // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
- // the fetch failure as a false positive, and just do normal OOM handling.
+ // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We
+ // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling.
val inputRDD = new FetchFailureThrowingRDD(sc)
- val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+ if (!oom) {
+ // we are trying to set up a case where a task is killed after a fetch failure -- this
+ // is just a helper to coordinate between the task thread and this thread that will
+ // kill the task
+ ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
+ }
+ val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
@@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
val serTask = serializer.serialize(task)
val taskDescription = createFakeTaskDescription(serTask)
- val (failReason, uncaughtExceptionHandler) =
- runTaskGetFailReasonAndExceptionHandler(taskDescription)
- // make sure the task failure just looks like a OOM, not a fetch failure
- assert(failReason.isInstanceOf[ExceptionFailure])
- val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
- verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
- assert(exceptionCaptor.getAllValues.size === 1)
- assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
- }
+ runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
+ }
test("Gracefully handle error in task deserialization") {
val conf = new SparkConf
@@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}
private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
- runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+ runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
}
private def runTaskGetFailReasonAndExceptionHandler(
- taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
+ taskDescription: TaskDescription,
+ killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
val mockBackend = mock[ExecutorBackend]
val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
var executor: Executor = null
+ val timedOut = new AtomicBoolean(false)
try {
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
uncaughtExceptionHandler = mockUncaughtExceptionHandler)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockBackend, taskDescription)
+ if (killTask) {
+ val killingThread = new Thread("kill-task") {
+ override def run(): Unit = {
+ // wait to kill the task until it has thrown a fetch failure
+ if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) {
+ // now we can kill the task
+ executor.killAllTasks(true, "Killed task, e.g. because of speculative execution")
+ } else {
+ timedOut.set(true)
+ }
+ }
+ }
+ killingThread.start()
+ }
eventually(timeout(5.seconds), interval(10.milliseconds)) {
assert(executor.numRunningTasks === 0)
}
+ assert(!timedOut.get(), "timed out waiting to be ready to kill tasks")
} finally {
if (executor != null) {
executor.stop()
@@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
orderedMock.verify(mockBackend)
.statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+ val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
orderedMock.verify(mockBackend)
- .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+ .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
// first statusUpdate for RUNNING has empty data
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
// second update is more interesting
@@ -321,7 +364,8 @@ class SimplePartition extends Partition {
class FetchFailureHidingRDD(
sc: SparkContext,
val input: FetchFailureThrowingRDD,
- throwOOM: Boolean) extends RDD[Int](input) {
+ throwOOM: Boolean,
+ interrupt: Boolean) extends RDD[Int](input) {
override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
val inItr = input.compute(split, context)
try {
@@ -330,6 +374,15 @@ class FetchFailureHidingRDD(
case t: Throwable =>
if (throwOOM) {
throw new OutOfMemoryError("OOM while handling another exception")
+ } else if (interrupt) {
+ // make sure our test is set up correctly
+ assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
+ // signal our test is ready for the task to get killed
+ ExecutorSuiteHelper.latches.latch1.countDown()
+ // then wait for another thread in the test to kill the task -- this latch
+ // is never actually decremented; we just wait to get killed.
+ ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS)
+ throw new IllegalStateException("timed out waiting to be interrupted")
} else {
throw new RuntimeException("User Exception that hides the original exception", t)
}
@@ -352,6 +405,11 @@ private class ExecutorSuiteHelper {
@volatile var testFailedReason: TaskFailedReason = _
}
+// helper for coordinating killing tasks
+private object ExecutorSuiteHelper {
+ var latches: ExecutorSuiteHelper = null
+}
+
private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
def writeExternal(out: ObjectOutput): Unit = {}
def readExternal(in: ObjectInput): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala
new file mode 100644
index 0000000000000..2bd32fc927e21
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.io
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Unit tests for instantiation of FileCommitProtocol implementations.
+ */
+class FileCommitProtocolInstantiationSuite extends SparkFunSuite {
+
+ test("Dynamic partitions require appropriate constructor") {
+
+ // you cannot instantiate a two-arg client with dynamic partitions
+ // enabled.
+ val ex = intercept[IllegalArgumentException] {
+ instantiateClassic(true)
+ }
+ // check the contents of the message and rethrow if unexpected.
+ // this preserves the stack trace of the unexpected
+ // exception.
+ if (!ex.toString.contains("Dynamic Partition Overwrite")) {
+ fail(s"Wrong text in caught exception $ex", ex)
+ }
+ }
+
+ test("Standard partitions work with classic constructor") {
+ instantiateClassic(false)
+ }
+
+ test("Three arg constructors have priority") {
+ assert(3 == instantiateNew(false).argCount,
+ "Wrong constructor argument count")
+ }
+
+ test("Three arg constructors have priority when dynamic") {
+ assert(3 == instantiateNew(true).argCount,
+ "Wrong constructor argument count")
+ }
+
+ test("The protocol must be of the correct class") {
+ intercept[ClassCastException] {
+ FileCommitProtocol.instantiate(
+ classOf[Other].getCanonicalName,
+ "job",
+ "path",
+ false)
+ }
+ }
+
+ test("If there is no matching constructor, class hierarchy is irrelevant") {
+ intercept[NoSuchMethodException] {
+ FileCommitProtocol.instantiate(
+ classOf[NoMatchingArgs].getCanonicalName,
+ "job",
+ "path",
+ false)
+ }
+ }
+
+ /**
+ * Create a classic two-arg protocol instance.
+ * @param dynamic dynamic partitioning mode
+ * @return the instance
+ */
+ private def instantiateClassic(dynamic: Boolean): ClassicConstructorCommitProtocol = {
+ FileCommitProtocol.instantiate(
+ classOf[ClassicConstructorCommitProtocol].getCanonicalName,
+ "job",
+ "path",
+ dynamic).asInstanceOf[ClassicConstructorCommitProtocol]
+ }
+
+ /**
+ * Create a three-arg protocol instance.
+ * @param dynamic dynamic partitioning mode
+ * @return the instance
+ */
+ private def instantiateNew(
+ dynamic: Boolean): FullConstructorCommitProtocol = {
+ FileCommitProtocol.instantiate(
+ classOf[FullConstructorCommitProtocol].getCanonicalName,
+ "job",
+ "path",
+ dynamic).asInstanceOf[FullConstructorCommitProtocol]
+ }
+
+}
+
+/**
+ * This protocol implementation does not have the new three-arg
+ * constructor.
+ */
+private class ClassicConstructorCommitProtocol(arg1: String, arg2: String)
+ extends HadoopMapReduceCommitProtocol(arg1, arg2) {
+}
+
+/**
+ * This protocol implementation does have the new three-arg constructor
+ * alongside the original, and a four-arg one for completeness.
+ * The final parameter of the primary constructor records the number of
+ * arguments used by the two- and three-arg constructors, for test assertions.
+ */
+private class FullConstructorCommitProtocol(
+ arg1: String,
+ arg2: String,
+ b: Boolean,
+ val argCount: Int)
+ extends HadoopMapReduceCommitProtocol(arg1, arg2, b) {
+
+ def this(arg1: String, arg2: String) = {
+ this(arg1, arg2, false, 2)
+ }
+
+ def this(arg1: String, arg2: String, b: Boolean) = {
+ this(arg1, arg2, false, 3)
+ }
+}
+
+/**
+ * This has the 2-arity constructor, but isn't the right class.
+ */
+private class Other(arg1: String, arg2: String) {
+
+}
+
+/**
+ * This has no matching arguments as well as being the wrong class.
+ */
+private class NoMatchingArgs() {
+
+}
+
diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
index 3b798e36b0499..2107559572d78 100644
--- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
@@ -21,11 +21,12 @@ import java.nio.ByteBuffer
import com.google.common.io.ByteStreams
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
+import org.apache.spark.internal.config
import org.apache.spark.network.util.ByteArrayWritableChannel
import org.apache.spark.util.io.ChunkedByteBuffer
-class ChunkedByteBufferSuite extends SparkFunSuite {
+class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext {
test("no chunks") {
val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer])
@@ -56,6 +57,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite {
assert(chunkedByteBuffer.getChunks().head.position() === 0)
}
+ test("SPARK-24107: writeFully() write buffer which is larger than bufferWriteChunkSize") {
+ try {
+ sc.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 32L * 1024L * 1024L)
+ val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(40 * 1024 * 1024)))
+ val byteArrayWritableChannel = new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt)
+ chunkedByteBuffer.writeFully(byteArrayWritableChannel)
+ assert(byteArrayWritableChannel.length() === chunkedByteBuffer.size)
+ } finally {
+ sc.conf.remove(config.BUFFER_WRITE_CHUNK_SIZE)
+ }
+ }
+
test("toArray()") {
val empty = ByteBuffer.wrap(Array.empty[Byte])
val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte))
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
index f7bc3725d7278..78423ee68a0ec 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
@@ -80,6 +80,7 @@ class NettyBlockTransferServiceSuite
private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
actualPort should be >= expectedPort
// avoid testing equality in case of simultaneous tests
+ // if `spark.testing` is true,
// the default value for `spark.port.maxRetries` is 100 under test
actualPort should be <= (expectedPort + 100)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e994d724c462f..191c61250ce21 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -1129,6 +1129,35 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
}.collect()
}
+ test("SPARK-23496: order of input partitions can result in severe skew in coalesce") {
+ val numInputPartitions = 100
+ val numCoalescedPartitions = 50
+ val locations = Array("locA", "locB")
+
+ val inputRDD = sc.makeRDD(Range(0, numInputPartitions).toArray[Int], numInputPartitions)
+ assert(inputRDD.getNumPartitions == numInputPartitions)
+
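+ // the first half of the input partitions prefer locA, the remainder prefer locB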
+ val locationPrefRDD = new LocationPrefRDD(inputRDD, { (p: Partition) =>
+ if (p.index < numCoalescedPartitions) {
+ Seq(locations(0))
+ } else {
+ Seq(locations(1))
+ }
+ })
+ val coalescedRDD = new CoalescedRDD(locationPrefRDD, numCoalescedPartitions)
+
+ val numPartsPerLocation = coalescedRDD
+ .getPartitions
+ .map(coalescedRDD.getPreferredLocations(_).head)
+ .groupBy(identity)
+ .mapValues(_.size)
+
+ // Make sure the coalesced partitions are distributed fairly evenly between the two locations.
+ // This should not become flaky since the DefaultPartitionsCoalescer uses a fixed seed.
+ assert(numPartsPerLocation(locations(0)) > 0.4 * numCoalescedPartitions)
+ assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions)
+ }
+
// NOTE
// Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
// running after them and if they access sc those tests will fail as sc is already closed, because
@@ -1210,3 +1239,16 @@ class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Seria
groups.toArray
}
}
+
+/** Alters the preferred locations of the parent RDD using provided function. */
+class LocationPrefRDD[T: ClassTag](
+ @transient var prev: RDD[T],
+ val locationPicker: Partition => Seq[String]) extends RDD[T](prev) {
+ override protected def getPartitions: Array[Partition] = prev.partitions
+
+ override def compute(partition: Partition, context: TaskContext): Iterator[T] =
+ null.asInstanceOf[Iterator[T]]
+
+ override def getPreferredLocations(partition: Partition): Seq[String] =
+ locationPicker(partition)
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
index afebcdd7b9e31..96c8404327e24 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
@@ -479,7 +479,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
test("blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") {
val allocationClientMock = mock[ExecutorAllocationClient]
- when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called"))
+ when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called"))
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] {
// To avoid a race between blacklisting and killing, it is important that the nodeBlacklist
// is updated before we ask the executor allocation client to kill all the executors
@@ -517,7 +517,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
}
blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures)
- verify(allocationClientMock, never).killExecutors(any(), any(), any())
+ verify(allocationClientMock, never).killExecutors(any(), any(), any(), any())
verify(allocationClientMock, never).killExecutorsOnHost(any())
// Enable auto-kill. Blacklist an executor and make sure killExecutors is called.
@@ -533,7 +533,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
}
blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures)
- verify(allocationClientMock).killExecutors(Seq("1"), true, true)
+ verify(allocationClientMock).killExecutors(Seq("1"), false, false, true)
val taskSetBlacklist3 = createTaskSetBlacklist(stageId = 1)
// Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole
@@ -545,13 +545,13 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
}
blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures)
- verify(allocationClientMock).killExecutors(Seq("2"), true, true)
+ verify(allocationClientMock).killExecutors(Seq("2"), false, false, true)
verify(allocationClientMock).killExecutorsOnHost("hostA")
}
test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") {
val allocationClientMock = mock[ExecutorAllocationClient]
- when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called"))
+ when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called"))
when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] {
// To avoid a race between blacklisting and killing, it is important that the nodeBlacklist
// is updated before we ask the executor allocation client to kill all the executors
@@ -571,16 +571,19 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
conf.set(config.BLACKLIST_KILL_ENABLED, false)
blacklist.updateBlacklistForFetchFailure("hostA", exec = "1")
- verify(allocationClientMock, never).killExecutors(any(), any(), any())
+ verify(allocationClientMock, never).killExecutors(any(), any(), any(), any())
verify(allocationClientMock, never).killExecutorsOnHost(any())
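+ // Even with kill disabled, the tracker should record the blacklisted executor under its host.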
+ assert(blacklist.nodeToBlacklistedExecs.contains("hostA"))
+ assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1"))
+
// Enable auto-kill. Blacklist an executor and make sure killExecutors is called.
conf.set(config.BLACKLIST_KILL_ENABLED, true)
blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock)
clock.advance(1000)
blacklist.updateBlacklistForFetchFailure("hostA", exec = "1")
- verify(allocationClientMock).killExecutors(Seq("1"), true, true)
+ verify(allocationClientMock).killExecutors(Seq("1"), false, false, true)
verify(allocationClientMock, never).killExecutorsOnHost(any())
assert(blacklist.executorIdToBlacklistStatus.contains("1"))
@@ -589,6 +592,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS)
assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS)
assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty)
+ assert(blacklist.nodeToBlacklistedExecs.contains("hostA"))
+ assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1"))
// Enable external shuffle service to see if all the executors on this node will be killed.
conf.set(config.SHUFFLE_SERVICE_ENABLED, true)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index d812b5bd92c1b..2987170bf5026 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1852,7 +1852,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assertDataStructuresEmpty()
}
- test("accumulators are updated on exception failures") {
+ test("accumulators are updated on exception failures and task killed") {
val acc1 = AccumulatorSuite.createLongAccum("ingenieur")
val acc2 = AccumulatorSuite.createLongAccum("boulanger")
val acc3 = AccumulatorSuite.createLongAccum("agriculteur")
@@ -1868,15 +1868,24 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val accUpdate3 = new LongAccumulator
accUpdate3.metadata = acc3.metadata
accUpdate3.setValue(18)
- val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3)
- val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo)
+
+ val accumUpdates1 = Seq(accUpdate1, accUpdate2)
+ val accumInfo1 = accumUpdates1.map(AccumulatorSuite.makeInfo)
val exceptionFailure = new ExceptionFailure(
new SparkException("fondue?"),
- accumInfo).copy(accums = accumUpdates)
+ accumInfo1).copy(accums = accumUpdates1)
submit(new MyRDD(sc, 1, Nil), Array(0))
runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result"))
+
assert(AccumulatorContext.get(acc1.id).get.value === 15L)
assert(AccumulatorContext.get(acc2.id).get.value === 13L)
+
+ val accumUpdates2 = Seq(accUpdate3)
+ val accumInfo2 = accumUpdates2.map(AccumulatorSuite.makeInfo)
+
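+ // TaskKilled now carries accumulator updates, just like the ExceptionFailure above.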
+ val taskKilled = new TaskKilled("test", accumInfo2, accums = accumUpdates2)
+ runEvent(makeCompletionEvent(taskSets.head.tasks.head, taskKilled, "result"))
+
assert(AccumulatorContext.get(acc3.id).get.value === 18L)
}
@@ -2146,6 +2155,58 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assertDataStructuresEmpty()
}
+ test("Trigger mapstage's job listener in submitMissingTasks") {
+ val rdd1 = new MyRDD(sc, 2, Nil)
+ val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2))
+ val rdd2 = new MyRDD(sc, 2, List(dep1), tracker = mapOutputTracker)
+ val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2))
+
+ val listener1 = new SimpleListener
+ val listener2 = new SimpleListener
+
+ submitMapStage(dep1, listener1)
+ submitMapStage(dep2, listener2)
+
+ // Complete stage 0.
+ assert(taskSets(0).stageId === 0)
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", rdd1.partitions.length)),
+ (Success, makeMapStatus("hostB", rdd1.partitions.length))))
+ assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet ===
+ HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
+ assert(listener1.results.size === 1)
+
+ // When attempting stage1, trigger a fetch failure.
+ assert(taskSets(1).stageId === 1)
+ complete(taskSets(1), Seq(
+ (Success, makeMapStatus("hostC", rdd2.partitions.length)),
+ (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null)))
+ scheduler.resubmitFailedStages()
+ // Stage1 listener should not have a result yet
+ assert(listener2.results.size === 0)
+
+ // Speculative task succeeded in stage1.
+ runEvent(makeCompletionEvent(
+ taskSets(1).tasks(1),
+ Success,
+ makeMapStatus("hostD", rdd2.partitions.length)))
+ // The stage1 listener still should not have a result, even though there are no missing
+ // partitions in it, because stage1 has failed and is not inside `runningStages` at this moment.
+ assert(listener2.results.size === 0)
+
+ // Stage0 should now be running as task set 2; make its task succeed
+ assert(taskSets(2).stageId === 0)
+ complete(taskSets(2), Seq(
+ (Success, makeMapStatus("hostC", rdd2.partitions.length))))
+ assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet ===
+ Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+
+ // After stage0 finishes, stage1 is submitted and found to have no missing
+ // partitions, so its listener is triggered.
+ assert(listener2.results.size === 1)
+ assertDataStructuresEmpty()
+ }
+
/**
* In this test, we run a map stage where one of the executors fails but we still receive a
* "zombie" complete message from that executor. We want to make sure the stage is not reported
@@ -2445,6 +2506,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val accumUpdates = reason match {
case Success => task.metrics.accumulators()
case ef: ExceptionFailure => ef.accums
+ case tk: TaskKilled => tk.accums
case _ => Seq.empty
}
CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index 73e7b3fe8c1de..e24d550a62665 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -47,7 +47,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
}
test("Simple replay") {
- val logFilePath = Utils.getFilePath(testDir, "events.txt")
+ val logFilePath = getFilePath(testDir, "events.txt")
val fstream = fileSystem.create(logFilePath)
val writer = new PrintWriter(fstream)
val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None,
@@ -97,7 +97,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
// scalastyle:on println
}
- val logFilePath = Utils.getFilePath(testDir, "events.lz4.inprogress")
+ val logFilePath = getFilePath(testDir, "events.lz4.inprogress")
val bytes = buffered.toByteArray
Utils.tryWithResource(fileSystem.create(logFilePath)) { fstream =>
fstream.write(bytes, 0, buffered.size)
@@ -129,7 +129,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
}
test("Replay incompatible event log") {
- val logFilePath = Utils.getFilePath(testDir, "incompatible.txt")
+ val logFilePath = getFilePath(testDir, "incompatible.txt")
val fstream = fileSystem.create(logFilePath)
val writer = new PrintWriter(fstream)
val applicationStart = SparkListenerApplicationStart("Incompatible App", None,
@@ -226,6 +226,12 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
}
}
+ private def getFilePath(dir: File, fileName: String): Path = {
+ assert(dir.isDirectory)
+ val path = new File(dir, fileName).getAbsolutePath
+ new Path(path)
+ }
+
/**
* A simple listener that buffers all the events it receives.
*
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index da6ecb82c7e42..6ffd1e84f7adb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.scheduler
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.Semaphore
import scala.collection.JavaConverters._
@@ -294,10 +295,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
val listener = new SaveStageAndTaskInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
- // just to make sure some of the tasks take a noticeable amount of time
+ // just to make sure some of the tasks and their deserialization take a noticeable
+ // amount of time
+ val slowDeserializable = new SlowDeserializable
val w = { i: Int =>
if (i == 0) {
Thread.sleep(100)
+ slowDeserializable.use()
}
i
}
@@ -485,6 +489,48 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
assert(bus.findListenersByClass[BasicJobCounter]().isEmpty)
}
+ Seq(true, false).foreach { throwInterruptedException =>
+ val suffix = if (throwInterruptedException) "throw interrupt" else "set Thread interrupted"
+ test(s"interrupt within listener is handled correctly: $suffix") {
+ val conf = new SparkConf(false)
+ .set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5)
+ val bus = new LiveListenerBus(conf)
+ val counter1 = new BasicJobCounter()
+ val counter2 = new BasicJobCounter()
+ val interruptingListener1 = new InterruptingListener(throwInterruptedException)
+ val interruptingListener2 = new InterruptingListener(throwInterruptedException)
+ bus.addToSharedQueue(counter1)
+ bus.addToSharedQueue(interruptingListener1)
+ bus.addToStatusQueue(counter2)
+ bus.addToEventLogQueue(interruptingListener2)
+ assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE))
+ assert(bus.findListenersByClass[BasicJobCounter]().size === 2)
+ assert(bus.findListenersByClass[InterruptingListener]().size === 2)
+
+ bus.start(mockSparkContext, mockMetricsSystem)
+
+ // after we post one event, both interrupting listeners should get removed, and the
+ // event log queue should be removed
+ bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
+ bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE))
+ assert(bus.findListenersByClass[BasicJobCounter]().size === 2)
+ assert(bus.findListenersByClass[InterruptingListener]().size === 0)
+ assert(counter1.count === 1)
+ assert(counter2.count === 1)
+
+ // posting more events should be fine, they'll just get processed from the OK queue.
+ (0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
+ bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(counter1.count === 6)
+ assert(counter2.count === 6)
+
+ // Make sure stopping works -- this requires putting a poison pill in all active queues, which
+ // would fail if our interrupted queue was still active, as its queue would be full.
+ bus.stop()
+ }
+ }
+
/**
* Assert that the given list of numbers has an average that is greater than zero.
*/
@@ -543,6 +589,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception }
}
+ /**
+ * A simple listener that interrupts on job end.
+ */
+ private class InterruptingListener(val throwInterruptedException: Boolean) extends SparkListener {
+ override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+ if (throwInterruptedException) {
+ throw new InterruptedException("got interrupted")
+ } else {
+ Thread.currentThread().interrupt()
+ }
+ }
+ }
}
// These classes can't be declared inside of the SparkListenerSuite class because we don't want
@@ -583,3 +641,12 @@ private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends Spar
case _ =>
}
}
+
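+// Externalizable payload whose deserialization sleeps briefly, so that task deserialization
+// time is measurably nonzero in the listener test above.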
+private class SlowDeserializable extends Externalizable {
+
+ override def writeExternal(out: ObjectOutput): Unit = { }
+
+ override def readExternal(in: ObjectInput): Unit = Thread.sleep(1)
+
+ def use(): Unit = { }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 6003899bb7bef..33f2ea1c94e75 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -917,4 +917,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
taskScheduler.initialize(new FakeSchedulerBackend)
}
}
+
+ test("Completions in zombie tasksets update status of non-zombie taskset") {
+ val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
+ val valueSer = SparkEnv.get.serializer.newInstance()
+
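+ // Helper to mark the task for the given partition as successfully completed on the given
+ // TaskSetManager.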
+ def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = {
+ val indexInTsm = tsm.partitionToIndex(partition)
+ val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head
+ val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
+ tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result)
+ }
+
+ // Submit a task set, have it fail with a fetch failure, and then re-submit the task attempt,
+ // two times, so we have three active task sets for one stage. (For this to really happen,
+ // you'd need the previous stage to also get restarted, and then succeed, in between each
+ // attempt, but that happens outside what we're mocking here.)
+ val zombieAttempts = (0 until 2).map { stageAttempt =>
+ val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt)
+ taskScheduler.submitTasks(attempt)
+ val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get
+ val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
+ taskScheduler.resourceOffers(offers)
+ assert(tsm.runningTasks === 10)
+ // fail attempt
+ tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED,
+ FetchFailed(null, 0, 0, 0, "fetch failed"))
+ // the attempt is a zombie, but the tasks are still running (this could be true even if
+ // we actively killed those tasks, as killing is best-effort)
+ assert(tsm.isZombie)
+ assert(tsm.runningTasks === 9)
+ tsm
+ }
+
+ // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for
+ // the stage, but this time with insufficient resources so not all tasks are active.
+
+ val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2)
+ taskScheduler.submitTasks(finalAttempt)
+ val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get
+ val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
+ val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task =>
+ finalAttempt.tasks(task.index).partitionId
+ }.toSet
+ assert(finalTsm.runningTasks === 5)
+ assert(!finalTsm.isZombie)
+
+ // We simulate late completions from our zombie tasksets, corresponding to all the pending
+ // partitions in our final attempt. This means we're only waiting on the tasks we've already
+ // launched.
+ val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions)
+ finalAttemptPendingPartitions.foreach { partition =>
+ completeTaskSuccessfully(zombieAttempts(0), partition)
+ }
+
+ // If there is another resource offer, we shouldn't run anything. Though our final attempt
+ // used to have pending tasks, now those tasks have been completed by zombie attempts. The
+ // remaining tasks to compute are already active in the non-zombie attempt.
+ assert(
+ taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty)
+
+ val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted
+
+ // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be
+ // marked as zombie.
+ // for each of the remaining tasks, find the tasksets with an active copy of the task, and
+ // finish the task.
+ remainingTasks.foreach { partition =>
+ val tsm = if (partition == 0) {
+ // we failed this task on both zombie attempts, this one is only present in the latest
+ // taskset
+ finalTsm
+ } else {
+ // should be active in every taskset. We choose a zombie taskset just to make sure that
+ // we transition the active taskset correctly even if the final completion comes
+ // from a zombie.
+ zombieAttempts(partition % 2)
+ }
+ completeTaskSuccessfully(tsm, partition)
+ }
+
+ assert(finalTsm.isZombie)
+
+ // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet
+ verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject())
+
+ // Finally, let's complete all the tasks. We simulate failures in attempt 1, but everything
+ // else succeeds, to make sure we get the right updates to the blacklist in all cases.
+ (zombieAttempts ++ Seq(finalTsm)).foreach { tsm =>
+ val stageAttempt = tsm.taskSet.stageAttemptId
+ tsm.runningTasksSet.foreach { index =>
+ if (stageAttempt == 1) {
+ tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost)
+ } else {
+ val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
+ tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result)
+ }
+ }
+
+ // we update the blacklist for the stage attempts with all successful tasks. Even though
+ // some tasksets had failures, we still consider them all successful from a blacklisting
+ // perspective, as the failures weren't from a problem with the tasks themselves.
+ verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject())
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
new file mode 100644
index 0000000000000..e57cb701b6284
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import java.io.Closeable
+import java.net._
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils
+
+class SocketAuthHelperSuite extends SparkFunSuite {
+
+ private val conf = new SparkConf()
+ private val authHelper = new SocketAuthHelper(conf)
+
+ test("successful auth") {
+ Utils.tryWithResource(new ServerThread()) { server =>
+ Utils.tryWithResource(server.createClient()) { client =>
+ authHelper.authToServer(client)
+ server.close()
+ server.join()
+ assert(server.error == null)
+ assert(server.authenticated)
+ }
+ }
+ }
+
+ test("failed auth") {
+ Utils.tryWithResource(new ServerThread()) { server =>
+ Utils.tryWithResource(server.createClient()) { client =>
+ val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128))
+ intercept[IllegalArgumentException] {
+ badHelper.authToServer(client)
+ }
+ server.close()
+ server.join()
+ assert(server.error != null)
+ assert(!server.authenticated)
+ }
+ }
+ }
+
+ private class ServerThread extends Thread with Closeable {
+
+ private val ss = new ServerSocket()
+ ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0))
+
+ @volatile var error: Exception = _
+ @volatile var authenticated = false
+
+ setDaemon(true)
+ start()
+
+ def createClient(): Socket = {
+ new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort())
+ }
+
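+ // Accept a single client connection and try to authenticate it, recording the outcome
+ // for the test to inspect.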
+ override def run(): Unit = {
+ var clientConn: Socket = null
+ try {
+ clientConn = ss.accept()
+ authHelper.authClient(clientConn)
+ authenticated = true
+ } catch {
+ case e: Exception =>
+ error = e
+ } finally {
+ Option(clientConn).foreach(_.close())
+ }
+ }
+
+ override def close(): Unit = {
+ try {
+ ss.close()
+ } finally {
+ interrupt()
+ }
+ }
+
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index dba1172d5fdbd..2d8a83c6fabed 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -108,7 +108,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
(shuffleBlockId, byteOutputStream.size().toLong)
}
- Seq((localBlockManagerId, shuffleBlockIdsAndSizes))
+ Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator
}
// Create a mocked shuffle handle to pass into HashShuffleReader.
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
index 55cebe7c8b6a8..f29dac965c803 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
@@ -85,6 +85,14 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
mapSideCombine = false
)))
+ // We support serialized shuffle if we do not need to do map-side aggregation
+ assert(canUseSerializedShuffle(shuffleDep(
+ partitioner = new HashPartitioner(2),
+ serializer = kryo,
+ keyOrdering = None,
+ aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+ mapSideCombine = false
+ )))
}
test("unsupported shuffle dependencies for serialized shuffle") {
@@ -111,14 +119,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
mapSideCombine = false
)))
- // We do not support shuffles that perform aggregation
- assert(!canUseSerializedShuffle(shuffleDep(
- partitioner = new HashPartitioner(2),
- serializer = kryo,
- keyOrdering = None,
- aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
- mapSideCombine = false
- )))
+ // We do not support serialized shuffle if we need to do map-side aggregation
assert(!canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
serializer = kryo,
diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala
index b74d6ee2ec836..1cd71955ad4d9 100644
--- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala
@@ -273,6 +273,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
assert(exec.info.isBlacklistedForStage === expectedBlacklistedFlag)
}
+ check[ExecutorSummaryWrapper](execIds.head) { exec =>
+ assert(exec.info.blacklistedInStages === Set(stages.head.stageId))
+ }
+
// Blacklisting node for stage
time += 1
listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage(
@@ -439,6 +443,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
assert(stage.info.numCompleteTasks === pending.size)
}
+ check[ExecutorSummaryWrapper](execIds.head) { exec =>
+ assert(exec.info.blacklistedInStages === Set())
+ }
+
// Submit stage 2.
time += 1
stages.last.submissionTime = Some(time)
@@ -453,6 +461,19 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
assert(stage.info.submissionTime === Some(new Date(stages.last.submissionTime.get)))
}
+ // Blacklisting node for stage
+ time += 1
+ listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage(
+ time = time,
+ hostId = "1.example.com",
+ executorFailures = 1,
+ stageId = stages.last.stageId,
+ stageAttemptId = stages.last.attemptId))
+
+ check[ExecutorSummaryWrapper](execIds.head) { exec =>
+ assert(exec.info.blacklistedInStages === Set(stages.last.stageId))
+ }
+
// Start and fail all tasks of stage 2.
time += 1
val s2Tasks = createTasks(4, execIds)
@@ -1068,6 +1089,42 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
}
}
+ test("skipped stages should be evicted before completed stages") {
+ val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2)
+ val listener = new AppStatusListener(store, testConf, true)
+
+ val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1")
+ val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2")
+
+ // Start job 1
+ time += 1
+ listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null))
+
+ // Start and stop stage 1
+ time += 1
+ stage1.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties()))
+
+ time += 1
+ stage1.completionTime = Some(time)
+ listener.onStageCompleted(SparkListenerStageCompleted(stage1))
+
+ // Stop job 1 and stage 2 will become SKIPPED
+ time += 1
+ listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded))
+
+ // Submit stage 3 and verify stage 2 is evicted
+ val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3")
+ time += 1
+ stage3.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties()))
+
+ assert(store.count(classOf[StageDataWrapper]) === 2)
+ intercept[NoSuchElementException] {
+ store.read(classOf[StageDataWrapper], Array(2, 0))
+ }
+ }
+
test("eviction should respect task completion time") {
val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2)
val listener = new AppStatusListener(store, testConf, true)
@@ -1100,6 +1157,39 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter {
}
}
+ test("lastStageAttempt should fail when the stage doesn't exist") {
+ val testConf = conf.clone().set(MAX_RETAINED_STAGES, 1)
+ val listener = new AppStatusListener(store, testConf, true)
+ val appStore = new AppStatusStore(store)
+
+ val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1")
+ val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2")
+ val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3")
+
+ time += 1
+ stage1.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties()))
+ stage1.completionTime = Some(time)
+ listener.onStageCompleted(SparkListenerStageCompleted(stage1))
+
+ // Make stage 3 complete before stage 2 so that stage 3 will be evicted
+ time += 1
+ stage3.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties()))
+ stage3.completionTime = Some(time)
+ listener.onStageCompleted(SparkListenerStageCompleted(stage3))
+
+ time += 1
+ stage2.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties()))
+ stage2.completionTime = Some(time)
+ listener.onStageCompleted(SparkListenerStageCompleted(stage2))
+
+ assert(appStore.asOption(appStore.lastStageAttempt(1)) === None)
+ assert(appStore.asOption(appStore.lastStageAttempt(2)).map(_.stageId) === Some(2))
+ assert(appStore.asOption(appStore.lastStageAttempt(3)) === None)
+ }
+
test("driver logs") {
val listener = new AppStatusListener(store, conf, true)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 5bfe9905ff17b..a2997dbd1b1ac 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -65,12 +65,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
}
// Create a mock managed buffer for testing
- def createMockManagedBuffer(): ManagedBuffer = {
+ def createMockManagedBuffer(size: Int = 1): ManagedBuffer = {
val mockManagedBuffer = mock(classOf[ManagedBuffer])
val in = mock(classOf[InputStream])
when(in.read(any())).thenReturn(1)
when(in.read(any(), any(), any())).thenReturn(1)
when(mockManagedBuffer.createInputStream()).thenReturn(in)
+ when(mockManagedBuffer.size()).thenReturn(size)
mockManagedBuffer
}
@@ -99,7 +100,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq),
(remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)
- )
+ ).toIterator
val iterator = new ShuffleBlockFetcherIterator(
TaskContext.empty(),
@@ -176,7 +177,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
})
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
@@ -244,7 +245,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
})
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
@@ -269,6 +270,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
intercept[FetchFailedException] { iterator.next() }
}
+ private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = {
+ val corruptStream = mock(classOf[InputStream])
+ when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+ val corruptBuffer = mock(classOf[ManagedBuffer])
+ when(corruptBuffer.size()).thenReturn(size)
+ when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+ corruptBuffer
+ }
+
test("retry corrupt blocks") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
@@ -284,11 +294,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)
-
- val corruptStream = mock(classOf[InputStream])
- when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
- val corruptBuffer = mock(classOf[ManagedBuffer])
- when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)
val transfer = mock(classOf[BlockTransferService])
@@ -301,7 +306,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+ ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
sem.release()
@@ -310,7 +315,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
})
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
@@ -339,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
Future {
// Return the first block, and then fail.
listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+ ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
sem.release()
}
}
@@ -352,6 +357,47 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
intercept[FetchFailedException] { iterator.next() }
}
+ test("big blocks are not checked for corruption") {
+ val corruptBuffer = mockCorruptBuffer(10000L)
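+ // The 10000-byte buffers are much larger than the small maxBytesInFlight set below (2048),
+ // so the iterator should treat them as "big" blocks and skip the corruption check.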
+
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+ doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0))
+ val localBlockLengths = Seq[Tuple2[BlockId, Long]](
+ ShuffleBlockId(0, 0, 0) -> corruptBuffer.size()
+ )
+
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val remoteBlockLengths = Seq[Tuple2[BlockId, Long]](
+ ShuffleBlockId(0, 1, 0) -> corruptBuffer.size()
+ )
+
+ val transfer = createMockTransfer(
+ Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer))
+
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (localBmId, localBlockLengths),
+ (remoteBmId, remoteBlockLengths)
+ ).toIterator
+
+ val taskContext = TaskContext.empty()
+ val iterator = new ShuffleBlockFetcherIterator(
+ taskContext,
+ transfer,
+ blockManager,
+ blocksByAddress,
+ (_, in) => new LimitedInputStream(in, 10000),
+ 2048,
+ Int.MaxValue,
+ Int.MaxValue,
+ Int.MaxValue,
+ true)
+ // Blocks should be returned without exceptions.
+ assert(Set(iterator.next()._1, iterator.next()._1) ===
+ Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0)))
+ }
+
test("retry corrupt blocks (disabled)") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
@@ -368,11 +414,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)
- val corruptStream = mock(classOf[InputStream])
- when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
- val corruptBuffer = mock(classOf[ManagedBuffer])
- when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
-
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
@@ -383,16 +424,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+ ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
listener.onBlockFetchSuccess(
- ShuffleBlockId(0, 2, 0).toString, corruptBuffer)
+ ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
sem.release()
}
}
})
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
@@ -450,7 +491,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
}
})
- def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = {
+ def fetchShuffleBlock(
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = {
// Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the
// construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks
// are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here.
@@ -468,17 +510,52 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
}
val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq))
+ (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator
fetchShuffleBlock(blocksByAddress1)
// `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch
// shuffle block to disk.
assert(tempFileManager == null)
val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
+ (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator
fetchShuffleBlock(blocksByAddress2)
// `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch
// shuffle block to disk.
assert(tempFileManager != null)
}
+
+ test("fail zero-size blocks") {
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+ // Make sure remote blocks would be returned
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()
+ )
+
+ val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)))
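+ // Every fetched block is replaced with a zero-size buffer, which the iterator should reject.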
+
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+
+ val taskContext = TaskContext.empty()
+ val iterator = new ShuffleBlockFetcherIterator(
+ taskContext,
+ transfer,
+ blockManager,
+ blocksByAddress.toIterator,
+ (_, in) => in,
+ 48 * 1024 * 1024,
+ Int.MaxValue,
+ Int.MaxValue,
+ Int.MaxValue,
+ true)
+
+ // All blocks fetched return zero length and should trigger a receive-side error:
+ val e = intercept[FetchFailedException] { iterator.next() }
+ assert(e.getMessage.contains("Received a zero-size buffer"))
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala
index da198f946fd64..ca352387055f4 100644
--- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala
@@ -51,27 +51,6 @@ class StorageSuite extends SparkFunSuite {
assert(status.diskUsed === 60L)
}
- test("storage status update non-RDD blocks") {
- val status = storageStatus1
- status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L))
- status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L))
- assert(status.blocks.size === 3)
- assert(status.memUsed === 160L)
- assert(status.memRemaining === 840L)
- assert(status.diskUsed === 140L)
- }
-
- test("storage status remove non-RDD blocks") {
- val status = storageStatus1
- status.removeBlock(TestBlockId("foo"))
- status.removeBlock(TestBlockId("faa"))
- assert(status.blocks.size === 1)
- assert(status.blocks.contains(TestBlockId("fee")))
- assert(status.memUsed === 10L)
- assert(status.memRemaining === 990L)
- assert(status.diskUsed === 20L)
- }
-
// For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks
private def storageStatus2: StorageStatus = {
val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L))
@@ -95,85 +74,6 @@ class StorageSuite extends SparkFunSuite {
assert(status.rddBlocks.contains(RDDBlockId(2, 2)))
assert(status.rddBlocks.contains(RDDBlockId(2, 3)))
assert(status.rddBlocks.contains(RDDBlockId(2, 4)))
- assert(status.rddBlocksById(0).size === 1)
- assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0)))
- assert(status.rddBlocksById(1).size === 1)
- assert(status.rddBlocksById(1).contains(RDDBlockId(1, 1)))
- assert(status.rddBlocksById(2).size === 3)
- assert(status.rddBlocksById(2).contains(RDDBlockId(2, 2)))
- assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3)))
- assert(status.rddBlocksById(2).contains(RDDBlockId(2, 4)))
- assert(status.memUsedByRdd(0) === 10L)
- assert(status.memUsedByRdd(1) === 100L)
- assert(status.memUsedByRdd(2) === 30L)
- assert(status.diskUsedByRdd(0) === 20L)
- assert(status.diskUsedByRdd(1) === 200L)
- assert(status.diskUsedByRdd(2) === 80L)
- assert(status.rddStorageLevel(0) === Some(memAndDisk))
- assert(status.rddStorageLevel(1) === Some(memAndDisk))
- assert(status.rddStorageLevel(2) === Some(memAndDisk))
-
- // Verify default values for RDDs that don't exist
- assert(status.rddBlocksById(10).isEmpty)
- assert(status.memUsedByRdd(10) === 0L)
- assert(status.diskUsedByRdd(10) === 0L)
- assert(status.rddStorageLevel(10) === None)
- }
-
- test("storage status update RDD blocks") {
- val status = storageStatus2
- status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L))
- status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L))
- status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L))
- assert(status.blocks.size === 7)
- assert(status.rddBlocks.size === 5)
- assert(status.rddBlocksById(0).size === 1)
- assert(status.rddBlocksById(1).size === 1)
- assert(status.rddBlocksById(2).size === 3)
- assert(status.memUsedByRdd(0) === 0L)
- assert(status.memUsedByRdd(1) === 100L)
- assert(status.memUsedByRdd(2) === 20L)
- assert(status.diskUsedByRdd(0) === 0L)
- assert(status.diskUsedByRdd(1) === 200L)
- assert(status.diskUsedByRdd(2) === 1060L)
- }
-
- test("storage status remove RDD blocks") {
- val status = storageStatus2
- status.removeBlock(TestBlockId("man"))
- status.removeBlock(RDDBlockId(1, 1))
- status.removeBlock(RDDBlockId(2, 2))
- status.removeBlock(RDDBlockId(2, 4))
- assert(status.blocks.size === 3)
- assert(status.rddBlocks.size === 2)
- assert(status.rddBlocks.contains(RDDBlockId(0, 0)))
- assert(status.rddBlocks.contains(RDDBlockId(2, 3)))
- assert(status.rddBlocksById(0).size === 1)
- assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0)))
- assert(status.rddBlocksById(1).size === 0)
- assert(status.rddBlocksById(2).size === 1)
- assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3)))
- assert(status.memUsedByRdd(0) === 10L)
- assert(status.memUsedByRdd(1) === 0L)
- assert(status.memUsedByRdd(2) === 10L)
- assert(status.diskUsedByRdd(0) === 20L)
- assert(status.diskUsedByRdd(1) === 0L)
- assert(status.diskUsedByRdd(2) === 20L)
- }
-
- test("storage status containsBlock") {
- val status = storageStatus2
- // blocks that actually exist
- assert(status.blocks.contains(TestBlockId("dan")) === status.containsBlock(TestBlockId("dan")))
- assert(status.blocks.contains(TestBlockId("man")) === status.containsBlock(TestBlockId("man")))
- assert(status.blocks.contains(RDDBlockId(0, 0)) === status.containsBlock(RDDBlockId(0, 0)))
- assert(status.blocks.contains(RDDBlockId(1, 1)) === status.containsBlock(RDDBlockId(1, 1)))
- assert(status.blocks.contains(RDDBlockId(2, 2)) === status.containsBlock(RDDBlockId(2, 2)))
- assert(status.blocks.contains(RDDBlockId(2, 3)) === status.containsBlock(RDDBlockId(2, 3)))
- assert(status.blocks.contains(RDDBlockId(2, 4)) === status.containsBlock(RDDBlockId(2, 4)))
- // blocks that don't exist
- assert(status.blocks.contains(TestBlockId("fan")) === status.containsBlock(TestBlockId("fan")))
- assert(status.blocks.contains(RDDBlockId(100, 0)) === status.containsBlock(RDDBlockId(100, 0)))
}
test("storage status getBlock") {
@@ -191,40 +91,6 @@ class StorageSuite extends SparkFunSuite {
assert(status.blocks.get(RDDBlockId(100, 0)) === status.getBlock(RDDBlockId(100, 0)))
}
- test("storage status num[Rdd]Blocks") {
- val status = storageStatus2
- assert(status.blocks.size === status.numBlocks)
- assert(status.rddBlocks.size === status.numRddBlocks)
- status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L))
- status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L))
- status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L))
- assert(status.blocks.size === status.numBlocks)
- assert(status.rddBlocks.size === status.numRddBlocks)
- assert(status.rddBlocksById(4).size === status.numRddBlocksById(4))
- assert(status.rddBlocksById(10).size === status.numRddBlocksById(10))
- status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L))
- status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L))
- status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L))
- status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L))
- assert(status.blocks.size === status.numBlocks)
- assert(status.rddBlocks.size === status.numRddBlocks)
- assert(status.rddBlocksById(4).size === status.numRddBlocksById(4))
- assert(status.rddBlocksById(10).size === status.numRddBlocksById(10))
- assert(status.rddBlocksById(100).size === status.numRddBlocksById(100))
- status.removeBlock(RDDBlockId(4, 0))
- status.removeBlock(RDDBlockId(10, 10))
- assert(status.blocks.size === status.numBlocks)
- assert(status.rddBlocks.size === status.numRddBlocks)
- assert(status.rddBlocksById(4).size === status.numRddBlocksById(4))
- assert(status.rddBlocksById(10).size === status.numRddBlocksById(10))
- // remove a block that doesn't exist
- status.removeBlock(RDDBlockId(1000, 999))
- assert(status.blocks.size === status.numBlocks)
- assert(status.rddBlocks.size === status.numRddBlocks)
- assert(status.rddBlocksById(4).size === status.numRddBlocksById(4))
- assert(status.rddBlocksById(10).size === status.numRddBlocksById(10))
- assert(status.rddBlocksById(1000).size === status.numRddBlocksById(1000))
- }
test("storage status memUsed, diskUsed, externalBlockStoreUsed") {
val status = storageStatus2
@@ -237,17 +103,6 @@ class StorageSuite extends SparkFunSuite {
status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L))
assert(status.memUsed === actualMemUsed)
assert(status.diskUsed === actualDiskUsed)
- status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L))
- status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L))
- status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L))
- assert(status.memUsed === actualMemUsed)
- assert(status.diskUsed === actualDiskUsed)
- status.removeBlock(TestBlockId("fire"))
- status.removeBlock(TestBlockId("man"))
- status.removeBlock(RDDBlockId(2, 2))
- status.removeBlock(RDDBlockId(2, 3))
- assert(status.memUsed === actualMemUsed)
- assert(status.diskUsed === actualDiskUsed)
}
// For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations
@@ -273,65 +128,6 @@ class StorageSuite extends SparkFunSuite {
Seq(info0, info1)
}
- test("StorageUtils.updateRddInfo") {
- val storageStatuses = stockStorageStatuses
- val rddInfos = stockRDDInfos
- StorageUtils.updateRddInfo(rddInfos, storageStatuses)
- assert(rddInfos(0).storageLevel === memAndDisk)
- assert(rddInfos(0).numCachedPartitions === 5)
- assert(rddInfos(0).memSize === 5L)
- assert(rddInfos(0).diskSize === 10L)
- assert(rddInfos(0).externalBlockStoreSize === 0L)
- assert(rddInfos(1).storageLevel === memAndDisk)
- assert(rddInfos(1).numCachedPartitions === 3)
- assert(rddInfos(1).memSize === 3L)
- assert(rddInfos(1).diskSize === 6L)
- assert(rddInfos(1).externalBlockStoreSize === 0L)
- }
-
- test("StorageUtils.getRddBlockLocations") {
- val storageStatuses = stockStorageStatuses
- val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses)
- val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses)
- assert(blockLocations0.size === 5)
- assert(blockLocations1.size === 3)
- assert(blockLocations0.contains(RDDBlockId(0, 0)))
- assert(blockLocations0.contains(RDDBlockId(0, 1)))
- assert(blockLocations0.contains(RDDBlockId(0, 2)))
- assert(blockLocations0.contains(RDDBlockId(0, 3)))
- assert(blockLocations0.contains(RDDBlockId(0, 4)))
- assert(blockLocations1.contains(RDDBlockId(1, 0)))
- assert(blockLocations1.contains(RDDBlockId(1, 1)))
- assert(blockLocations1.contains(RDDBlockId(1, 2)))
- assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1"))
- assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1"))
- assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2"))
- assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2"))
- assert(blockLocations0(RDDBlockId(0, 4)) === Seq("cat:3"))
- assert(blockLocations1(RDDBlockId(1, 0)) === Seq("duck:2"))
- assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2"))
- assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3"))
- }
-
- test("StorageUtils.getRddBlockLocations with multiple locations") {
- val storageStatuses = stockStorageStatuses
- storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L))
- storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L))
- storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L))
- val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses)
- val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses)
- assert(blockLocations0.size === 5)
- assert(blockLocations1.size === 3)
- assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1", "cat:3"))
- assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1"))
- assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2"))
- assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2"))
- assert(blockLocations0(RDDBlockId(0, 4)) === Seq("dog:1", "cat:3"))
- assert(blockLocations1(RDDBlockId(1, 0)) === Seq("dog:1", "duck:2"))
- assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2"))
- assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3"))
- }
-
private val offheap = StorageLevel.OFF_HEAP
// For testing add, update, remove, get, and contains etc. for both RDD and non-RDD onheap
// and offheap blocks
@@ -373,21 +169,6 @@ class StorageSuite extends SparkFunSuite {
status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L))
assert(status.memUsed === actualMemUsed)
assert(status.diskUsed === actualDiskUsed)
-
- status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L))
- status.updateBlock(RDDBlockId(0, 0), BlockStatus(offheap, 4L, 0L))
- status.updateBlock(RDDBlockId(1, 1), BlockStatus(offheap, 4L, 0L))
- assert(status.memUsed === actualMemUsed)
- assert(status.diskUsed === actualDiskUsed)
- assert(status.onHeapMemUsed.get === actualOnHeapMemUsed)
- assert(status.offHeapMemUsed.get === actualOffHeapMemUsed)
-
- status.removeBlock(TestBlockId("fire"))
- status.removeBlock(TestBlockId("man"))
- status.removeBlock(RDDBlockId(2, 2))
- status.removeBlock(RDDBlockId(2, 3))
- assert(status.memUsed === actualMemUsed)
- assert(status.diskUsed === actualDiskUsed)
}
private def storageStatus4: StorageStatus = {
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
index 0aeddf730cd35..6044563f7dde7 100644
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -28,13 +28,74 @@ import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.status.AppStatusStore
+import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus}
import org.apache.spark.status.config._
-import org.apache.spark.ui.jobs.{StagePage, StagesTab}
+import org.apache.spark.ui.jobs.{ApiHelper, StagePage, StagesTab, TaskPagedTable}
class StagePageSuite extends SparkFunSuite with LocalSparkContext {
private val peakExecutionMemory = 10
+ test("ApiHelper.COLUMN_TO_INDEX should match headers of the task table") {
+ val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L)
+ val statusStore = AppStatusStore.createLiveStore(conf)
+ try {
+ val stageData = new StageData(
+ status = StageStatus.ACTIVE,
+ stageId = 1,
+ attemptId = 1,
+ numTasks = 1,
+ numActiveTasks = 1,
+ numCompleteTasks = 1,
+ numFailedTasks = 1,
+ numKilledTasks = 1,
+ numCompletedIndices = 1,
+
+ executorRunTime = 1L,
+ executorCpuTime = 1L,
+ submissionTime = None,
+ firstTaskLaunchedTime = None,
+ completionTime = None,
+ failureReason = None,
+
+ inputBytes = 1L,
+ inputRecords = 1L,
+ outputBytes = 1L,
+ outputRecords = 1L,
+ shuffleReadBytes = 1L,
+ shuffleReadRecords = 1L,
+ shuffleWriteBytes = 1L,
+ shuffleWriteRecords = 1L,
+ memoryBytesSpilled = 1L,
+ diskBytesSpilled = 1L,
+
+ name = "stage1",
+ description = Some("description"),
+ details = "detail",
+ schedulingPool = "pool1",
+
+ rddIds = Seq(1),
+ accumulatorUpdates = Seq(new UIAccumulableInfo(0L, "acc", None, "value")),
+ tasks = None,
+ executorSummary = None,
+ killedTasksSummary = Map.empty
+ )
+ val taskTable = new TaskPagedTable(
+ stageData,
+ basePath = "/a/b/c",
+ currentTime = 0,
+ pageSize = 10,
+ sortColumn = "Index",
+ desc = false,
+ store = statusStore
+ )
+ val columnNames = (taskTable.headers \ "th" \ "a").map(_.child(1).text).toSet
+ assert(columnNames === ApiHelper.COLUMN_TO_INDEX.keySet)
+ } finally {
+ statusStore.close()
+ }
+ }
+
test("peak execution memory should displayed") {
val html = renderStagePage().toString().toLowerCase(Locale.ROOT)
val targetString = "peak execution memory"
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
index ed51fc445fdfb..e86cadfeebcff 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
@@ -707,6 +707,23 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B
}
}
+ test("stages page should show skipped stages") {
+ withSpark(newSparkContext()) { sc =>
+ val rdd = sc.parallelize(0 to 100, 100).repartition(10).cache()
+ rdd.count()
+ rdd.count()
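+ // The second count finds the data already cached, so its shuffle map stage is skipped.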
+
+ eventually(timeout(5 seconds), interval(50 milliseconds)) {
+ goToUi(sc, "/stages")
+ find(id("skipped")).get.text should be("Skipped Stages (1)")
+ }
+ val stagesJson = getJson(sc.ui.get, "stages")
+ stagesJson.children.size should be (4)
+ val stagesStatus = stagesJson.children.map(_ \ "status")
+ stagesStatus.count(_ == JString(StageStatus.SKIPPED.name())) should be (1)
+ }
+ }
+
def goToUi(sc: SparkContext, path: String): Unit = {
goToUi(sc.ui.get, path)
}
diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala
index a71521c91d2f2..cdc7f541b9552 100644
--- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ui.storage
+import javax.servlet.http.HttpServletRequest
+
import org.mockito.Mockito._
import org.apache.spark.SparkFunSuite
@@ -29,6 +31,7 @@ class StoragePageSuite extends SparkFunSuite {
val storageTab = mock(classOf[StorageTab])
when(storageTab.basePath).thenReturn("http://localhost:4040")
val storagePage = new StoragePage(storageTab, null)
+ val request = mock(classOf[HttpServletRequest])
test("rddTable") {
val rdd1 = new RDDStorageInfo(1,
@@ -61,7 +64,7 @@ class StoragePageSuite extends SparkFunSuite {
None,
None)
- val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3))
+ val xmlNodes = storagePage.rddTable(request, Seq(rdd1, rdd2, rdd3))
val headers = Seq(
"ID",
@@ -94,7 +97,7 @@ class StoragePageSuite extends SparkFunSuite {
}
test("empty rddTable") {
- assert(storagePage.rddTable(Seq.empty).isEmpty)
+ assert(storagePage.rddTable(request, Seq.empty).isEmpty)
}
test("streamBlockStorageLevelDescriptionAndSize") {
diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
index a04644d57ed88..fe0a9a471a651 100644
--- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import org.apache.spark._
+import org.apache.spark.serializer.JavaSerializer
class AccumulatorV2Suite extends SparkFunSuite {
@@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite {
assert(acc3.isZero)
assert(acc3.value === "")
}
+
+ test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") {
+ class MyData(val i: Int) extends Serializable
+ val param = new AccumulatorParam[MyData] {
+ override def zero(initialValue: MyData): MyData = new MyData(0)
+ override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i)
+ }
+
+ val acc = new LegacyAccumulatorWrapper(new MyData(0), param)
+ acc.metadata = AccumulatorMetadata(
+ AccumulatorContext.newId(),
+ Some("test"),
+ countFailedValues = false)
+ AccumulatorContext.register(acc)
+
+ val ser = new JavaSerializer(new SparkConf).newInstance()
+ ser.serialize(acc)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 4abbb8e7894f5..74b72d940eeef 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -317,7 +317,7 @@ class JsonProtocolSuite extends SparkFunSuite {
test("SparkListenerJobStart backward compatibility") {
// Prior to Spark 1.2.0, SparkListenerJobStart did not have a "Stage Infos" property.
val stageIds = Seq[Int](1, 2, 3, 4)
- val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500))
+ val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L))
val dummyStageInfos =
stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown"))
val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties)
@@ -331,7 +331,7 @@ class JsonProtocolSuite extends SparkFunSuite {
// Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property.
// Also, SparkListenerJobEnd did not have a "Completion Time" property.
val stageIds = Seq[Int](1, 2, 3, 4)
- val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50))
+ val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40L, x * 50L))
val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties)
val oldStartEvent = JsonProtocol.jobStartToJson(jobStart)
.removeField({ _._1 == "Submission Time"})
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index eaea6b030c154..418d2f9b88500 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -648,6 +648,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
test("fetch hcfs dir") {
val tempDir = Utils.createTempDir()
val sourceDir = new File(tempDir, "source-dir")
+ sourceDir.mkdir()
val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath)
val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir)
val targetDir = new File(tempDir, "target-dir")
@@ -1167,6 +1168,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port")
}
}
+
+ object MalformedClassObject {
+ class MalformedClass
+ }
+
+ test("Safe getSimpleName") {
+ // getSimpleName on class of MalformedClass will result in error: Malformed class name
+ // Utils.getSimpleName works
+ val err = intercept[java.lang.InternalError] {
+ classOf[MalformedClassObject.MalformedClass].getSimpleName
+ }
+ assert(err.getMessage === "Malformed class name")
+
+ assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) ===
+ "UtilsSuite$MalformedClassObject$MalformedClass")
+ }
}
private class SimpleExtension
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 47173b89e91e2..3e56db5ea116a 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark._
import org.apache.spark.memory.MemoryTestingUtils
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.unsafe.array.LongArray
-import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.unsafe.memory.OnHeapMemoryBlock
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordPointerAndKeyPrefix, UnsafeSortDataFormat}
class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
@@ -105,9 +105,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
// the form [150000000, 150000001, 150000002, ...., 300000000, 0, 1, 2, ..., 149999999]
// that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi()
val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i }
- val buf = new LongArray(MemoryBlock.fromLongArray(ref))
- val tmp = new Array[Long](size/2)
- val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp))
+ val buf = new LongArray(OnHeapMemoryBlock.fromArray(ref))
+ val tmpBuf = new LongArray(new OnHeapMemoryBlock((size/2) * 8L))
new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort(
buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] {
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index d5956ea32096a..ddf3740e76a7a 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -27,7 +27,7 @@ import com.google.common.primitives.Ints
import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.unsafe.array.LongArray
-import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.unsafe.memory.OnHeapMemoryBlock
import org.apache.spark.util.collection.Sorter
import org.apache.spark.util.random.XORShiftRandom
@@ -78,14 +78,14 @@ class RadixSortSuite extends SparkFunSuite with Logging {
private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = {
val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand }
val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0)
- (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
+ (ref.map(i => new JLong(i)), new LongArray(OnHeapMemoryBlock.fromArray(extended)))
}
private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = {
val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand }
val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0)
- (new LongArray(MemoryBlock.fromLongArray(ref)),
- new LongArray(MemoryBlock.fromLongArray(extended)))
+ (new LongArray(OnHeapMemoryBlock.fromArray(ref)),
+ new LongArray(OnHeapMemoryBlock.fromArray(extended)))
}
private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = {
@@ -110,7 +110,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
}
private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
- val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
+ val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L))
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] {
override def compare(
diff --git a/dev/.rat-excludes b/dev/.rat-excludes
index 243fbe3e1bc24..9552d001a079c 100644
--- a/dev/.rat-excludes
+++ b/dev/.rat-excludes
@@ -105,3 +105,4 @@ META-INF/*
spark-warehouse
structured-streaming/*
kafka-source-initial-offset-version-2.1.0.bin
+kafka-source-initial-offset-future-version.bin
diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh
index a3579f21fc539..5faa3d3260a56 100755
--- a/dev/create-release/release-build.sh
+++ b/dev/create-release/release-build.sh
@@ -164,8 +164,6 @@ if [[ "$1" == "package" ]]; then
tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \
--detach-sig spark-$SPARK_VERSION.tgz
- echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \
- spark-$SPARK_VERSION.tgz.md5
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \
SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512
rm -rf spark-$SPARK_VERSION
@@ -215,9 +213,6 @@ if [[ "$1" == "package" ]]; then
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \
--output $R_DIST_NAME.asc \
--detach-sig $R_DIST_NAME
- echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \
- MD5 $R_DIST_NAME > \
- $R_DIST_NAME.md5
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \
SHA512 $R_DIST_NAME > \
$R_DIST_NAME.sha512
@@ -234,9 +229,6 @@ if [[ "$1" == "package" ]]; then
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \
--output $PYTHON_DIST_NAME.asc \
--detach-sig $PYTHON_DIST_NAME
- echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \
- MD5 $PYTHON_DIST_NAME > \
- $PYTHON_DIST_NAME.md5
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \
SHA512 $PYTHON_DIST_NAME > \
$PYTHON_DIST_NAME.sha512
@@ -247,9 +239,6 @@ if [[ "$1" == "package" ]]; then
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \
--output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \
--detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz
- echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \
- MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \
- spark-$SPARK_VERSION-bin-$NAME.tgz.md5
echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \
SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \
spark-$SPARK_VERSION-bin-$NAME.tgz.sha512
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 48e54568e6fc6..723180a14febb 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -48,7 +48,7 @@ commons-lang-2.6.jar
commons-lang3-3.5.jar
commons-logging-1.1.3.jar
commons-math3-3.4.1.jar
-commons-net-2.2.jar
+commons-net-3.1.jar
commons-pool-1.5.4.jar
compress-lzf-1.0.3.jar
core-1.1.2.jar
@@ -157,20 +157,20 @@ objenesis-2.1.jar
okhttp-3.8.1.jar
okio-1.13.0.jar
opencsv-2.3.jar
-orc-core-1.4.1-nohive.jar
-orc-mapreduce-1.4.1-nohive.jar
+orc-core-1.4.4-nohive.jar
+orc-mapreduce-1.4.4-nohive.jar
oro-2.0.8.jar
osgi-resource-locator-1.0.1.jar
paranamer-2.8.jar
-parquet-column-1.8.2.jar
-parquet-common-1.8.2.jar
-parquet-encoding-1.8.2.jar
-parquet-format-2.3.1.jar
-parquet-hadoop-1.8.2.jar
+parquet-column-1.10.0.jar
+parquet-common-1.10.0.jar
+parquet-encoding-1.10.0.jar
+parquet-format-2.4.0.jar
+parquet-hadoop-1.10.0.jar
parquet-hadoop-bundle-1.6.0.jar
-parquet-jackson-1.8.2.jar
+parquet-jackson-1.10.0.jar
protobuf-java-2.5.0.jar
-py4j-0.10.6.jar
+py4j-0.10.7.jar
pyrolite-4.13.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
@@ -182,7 +182,7 @@ slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snakeyaml-1.15.jar
snappy-0.2.jar
-snappy-java-1.1.2.6.jar
+snappy-java-1.1.7.1.jar
spire-macros_2.11-0.13.0.jar
spire_2.11-0.13.0.jar
stax-api-1.0-2.jar
@@ -190,7 +190,7 @@ stax-api-1.0.1.jar
stream-2.7.0.jar
stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-2.5.9.jar
+univocity-parsers-2.6.3.jar
validation-api-1.1.0.Final.jar
xbean-asm5-shaded-4.4.jar
xercesImpl-2.9.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 1807a77900e52..ea08a001a1c9b 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -48,7 +48,7 @@ commons-lang-2.6.jar
commons-lang3-3.5.jar
commons-logging-1.1.3.jar
commons-math3-3.4.1.jar
-commons-net-2.2.jar
+commons-net-3.1.jar
commons-pool-1.5.4.jar
compress-lzf-1.0.3.jar
core-1.1.2.jar
@@ -158,20 +158,20 @@ objenesis-2.1.jar
okhttp-3.8.1.jar
okio-1.13.0.jar
opencsv-2.3.jar
-orc-core-1.4.1-nohive.jar
-orc-mapreduce-1.4.1-nohive.jar
+orc-core-1.4.4-nohive.jar
+orc-mapreduce-1.4.4-nohive.jar
oro-2.0.8.jar
osgi-resource-locator-1.0.1.jar
paranamer-2.8.jar
-parquet-column-1.8.2.jar
-parquet-common-1.8.2.jar
-parquet-encoding-1.8.2.jar
-parquet-format-2.3.1.jar
-parquet-hadoop-1.8.2.jar
+parquet-column-1.10.0.jar
+parquet-common-1.10.0.jar
+parquet-encoding-1.10.0.jar
+parquet-format-2.4.0.jar
+parquet-hadoop-1.10.0.jar
parquet-hadoop-bundle-1.6.0.jar
-parquet-jackson-1.8.2.jar
+parquet-jackson-1.10.0.jar
protobuf-java-2.5.0.jar
-py4j-0.10.6.jar
+py4j-0.10.7.jar
pyrolite-4.13.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
@@ -183,7 +183,7 @@ slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snakeyaml-1.15.jar
snappy-0.2.jar
-snappy-java-1.1.2.6.jar
+snappy-java-1.1.7.1.jar
spire-macros_2.11-0.13.0.jar
spire_2.11-0.13.0.jar
stax-api-1.0-2.jar
@@ -191,7 +191,7 @@ stax-api-1.0.1.jar
stream-2.7.0.jar
stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-2.5.9.jar
+univocity-parsers-2.6.3.jar
validation-api-1.1.0.Final.jar
xbean-asm5-shaded-4.4.jar
xercesImpl-2.9.1.jar
diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1
new file mode 100644
index 0000000000000..da874026d7d10
--- /dev/null
+++ b/dev/deps/spark-deps-hadoop-3.1
@@ -0,0 +1,221 @@
+HikariCP-java7-2.4.12.jar
+JavaEWAH-0.3.2.jar
+RoaringBitmap-0.5.11.jar
+ST4-4.0.4.jar
+accessors-smart-1.2.jar
+activation-1.1.1.jar
+aircompressor-0.8.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
+antlr4-runtime-4.7.jar
+aopalliance-1.0.jar
+aopalliance-repackaged-2.4.0-b34.jar
+apache-log4j-extras-1.2.17.jar
+arpack_combined_all-0.1.jar
+arrow-format-0.8.0.jar
+arrow-memory-0.8.0.jar
+arrow-vector-0.8.0.jar
+automaton-1.11-8.jar
+avro-1.7.7.jar
+avro-ipc-1.7.7.jar
+avro-mapred-1.7.7-hadoop2.jar
+base64-2.3.8.jar
+bcprov-jdk15on-1.58.jar
+bonecp-0.8.0.RELEASE.jar
+breeze-macros_2.11-0.13.2.jar
+breeze_2.11-0.13.2.jar
+calcite-avatica-1.2.0-incubating.jar
+calcite-core-1.2.0-incubating.jar
+calcite-linq4j-1.2.0-incubating.jar
+chill-java-0.8.4.jar
+chill_2.11-0.8.4.jar
+commons-beanutils-1.9.3.jar
+commons-cli-1.2.jar
+commons-codec-1.10.jar
+commons-collections-3.2.2.jar
+commons-compiler-3.0.8.jar
+commons-compress-1.4.1.jar
+commons-configuration2-2.1.1.jar
+commons-crypto-1.0.0.jar
+commons-daemon-1.0.13.jar
+commons-dbcp-1.4.jar
+commons-httpclient-3.1.jar
+commons-io-2.4.jar
+commons-lang-2.6.jar
+commons-lang3-3.5.jar
+commons-logging-1.1.3.jar
+commons-math3-3.4.1.jar
+commons-net-3.1.jar
+commons-pool-1.5.4.jar
+compress-lzf-1.0.3.jar
+core-1.1.2.jar
+curator-client-2.12.0.jar
+curator-framework-2.12.0.jar
+curator-recipes-2.12.0.jar
+datanucleus-api-jdo-3.2.6.jar
+datanucleus-core-3.2.10.jar
+datanucleus-rdbms-3.2.9.jar
+derby-10.12.1.1.jar
+dnsjava-2.1.7.jar
+ehcache-3.3.1.jar
+eigenbase-properties-1.1.5.jar
+flatbuffers-1.2.0-3f79e055.jar
+generex-1.0.1.jar
+geronimo-jcache_1.0_spec-1.0-alpha-1.jar
+gson-2.2.4.jar
+guava-14.0.1.jar
+guice-4.0.jar
+guice-servlet-4.0.jar
+hadoop-annotations-3.1.0.jar
+hadoop-auth-3.1.0.jar
+hadoop-client-3.1.0.jar
+hadoop-common-3.1.0.jar
+hadoop-hdfs-client-3.1.0.jar
+hadoop-mapreduce-client-common-3.1.0.jar
+hadoop-mapreduce-client-core-3.1.0.jar
+hadoop-mapreduce-client-jobclient-3.1.0.jar
+hadoop-yarn-api-3.1.0.jar
+hadoop-yarn-client-3.1.0.jar
+hadoop-yarn-common-3.1.0.jar
+hadoop-yarn-registry-3.1.0.jar
+hadoop-yarn-server-common-3.1.0.jar
+hadoop-yarn-server-web-proxy-3.1.0.jar
+hk2-api-2.4.0-b34.jar
+hk2-locator-2.4.0-b34.jar
+hk2-utils-2.4.0-b34.jar
+hppc-0.7.2.jar
+htrace-core4-4.1.0-incubating.jar
+httpclient-4.5.4.jar
+httpcore-4.4.8.jar
+ivy-2.4.0.jar
+jackson-annotations-2.6.7.jar
+jackson-core-2.6.7.jar
+jackson-core-asl-1.9.13.jar
+jackson-databind-2.6.7.1.jar
+jackson-dataformat-yaml-2.6.7.jar
+jackson-jaxrs-base-2.7.8.jar
+jackson-jaxrs-json-provider-2.7.8.jar
+jackson-mapper-asl-1.9.13.jar
+jackson-module-jaxb-annotations-2.6.7.jar
+jackson-module-paranamer-2.7.9.jar
+jackson-module-scala_2.11-2.6.7.1.jar
+janino-3.0.8.jar
+java-xmlbuilder-1.1.jar
+javassist-3.18.1-GA.jar
+javax.annotation-api-1.2.jar
+javax.inject-1.jar
+javax.inject-2.4.0-b34.jar
+javax.servlet-api-3.1.0.jar
+javax.ws.rs-api-2.0.1.jar
+javolution-5.5.1.jar
+jaxb-api-2.2.11.jar
+jcip-annotations-1.0-1.jar
+jcl-over-slf4j-1.7.16.jar
+jdo-api-3.0.1.jar
+jersey-client-2.22.2.jar
+jersey-common-2.22.2.jar
+jersey-container-servlet-2.22.2.jar
+jersey-container-servlet-core-2.22.2.jar
+jersey-guava-2.22.2.jar
+jersey-media-jaxb-2.22.2.jar
+jersey-server-2.22.2.jar
+jets3t-0.9.4.jar
+jetty-webapp-9.3.20.v20170531.jar
+jetty-xml-9.3.20.v20170531.jar
+jline-2.12.1.jar
+joda-time-2.9.3.jar
+jodd-core-3.5.2.jar
+jpam-1.1.jar
+json-smart-2.3.jar
+json4s-ast_2.11-3.5.3.jar
+json4s-core_2.11-3.5.3.jar
+json4s-jackson_2.11-3.5.3.jar
+json4s-scalap_2.11-3.5.3.jar
+jsp-api-2.1.jar
+jsr305-1.3.9.jar
+jta-1.1.jar
+jtransforms-2.4.0.jar
+jul-to-slf4j-1.7.16.jar
+kerb-admin-1.0.1.jar
+kerb-client-1.0.1.jar
+kerb-common-1.0.1.jar
+kerb-core-1.0.1.jar
+kerb-crypto-1.0.1.jar
+kerb-identity-1.0.1.jar
+kerb-server-1.0.1.jar
+kerb-simplekdc-1.0.1.jar
+kerb-util-1.0.1.jar
+kerby-asn1-1.0.1.jar
+kerby-config-1.0.1.jar
+kerby-pkix-1.0.1.jar
+kerby-util-1.0.1.jar
+kerby-xdr-1.0.1.jar
+kryo-shaded-3.0.3.jar
+kubernetes-client-3.0.0.jar
+kubernetes-model-2.0.0.jar
+leveldbjni-all-1.8.jar
+libfb303-0.9.3.jar
+libthrift-0.9.3.jar
+log4j-1.2.17.jar
+logging-interceptor-3.8.1.jar
+lz4-java-1.4.0.jar
+machinist_2.11-0.6.1.jar
+macro-compat_2.11-1.1.1.jar
+mesos-1.4.0-shaded-protobuf.jar
+metrics-core-3.1.5.jar
+metrics-graphite-3.1.5.jar
+metrics-json-3.1.5.jar
+metrics-jvm-3.1.5.jar
+minlog-1.3.0.jar
+mssql-jdbc-6.2.1.jre7.jar
+netty-3.9.9.Final.jar
+netty-all-4.1.17.Final.jar
+nimbus-jose-jwt-4.41.1.jar
+objenesis-2.1.jar
+okhttp-2.7.5.jar
+okhttp-3.8.1.jar
+okio-1.13.0.jar
+opencsv-2.3.jar
+orc-core-1.4.4-nohive.jar
+orc-mapreduce-1.4.4-nohive.jar
+oro-2.0.8.jar
+osgi-resource-locator-1.0.1.jar
+paranamer-2.8.jar
+parquet-column-1.10.0.jar
+parquet-common-1.10.0.jar
+parquet-encoding-1.10.0.jar
+parquet-format-2.4.0.jar
+parquet-hadoop-1.10.0.jar
+parquet-hadoop-bundle-1.6.0.jar
+parquet-jackson-1.10.0.jar
+protobuf-java-2.5.0.jar
+py4j-0.10.7.jar
+pyrolite-4.13.jar
+re2j-1.1.jar
+scala-compiler-2.11.8.jar
+scala-library-2.11.8.jar
+scala-parser-combinators_2.11-1.0.4.jar
+scala-reflect-2.11.8.jar
+scala-xml_2.11-1.0.5.jar
+shapeless_2.11-2.3.2.jar
+slf4j-api-1.7.16.jar
+slf4j-log4j12-1.7.16.jar
+snakeyaml-1.15.jar
+snappy-0.2.jar
+snappy-java-1.1.7.1.jar
+spire-macros_2.11-0.13.0.jar
+spire_2.11-0.13.0.jar
+stax-api-1.0.1.jar
+stax2-api-3.1.4.jar
+stream-2.7.0.jar
+stringtemplate-3.2.1.jar
+super-csv-2.2.0.jar
+token-provider-1.0.1.jar
+univocity-parsers-2.6.3.jar
+validation-api-1.1.0.Final.jar
+woodstox-core-5.0.3.jar
+xbean-asm5-shaded-4.4.jar
+xz-1.0.jar
+zjsonpatch-0.3.0.jar
+zookeeper-3.4.9.jar
+zstd-jni-1.3.2-2.jar
diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh
index 8b02446b2f15f..84233c64caa9c 100755
--- a/dev/make-distribution.sh
+++ b/dev/make-distribution.sh
@@ -72,9 +72,17 @@ while (( "$#" )); do
--help)
exit_with_usage
;;
- *)
+ --*)
+ echo "Error: $1 is not supported"
+ exit_with_usage
+ ;;
+ -*)
break
;;
+ *)
+ echo "Error: $1 is not supported"
+ exit_with_usage
+ ;;
esac
shift
done
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index 6b244d8184b2c..7f46a1c8f6a7c 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -101,14 +101,15 @@ def continue_maybe(prompt):
def clean_up():
- print("Restoring head pointer to %s" % original_head)
- run_cmd("git checkout %s" % original_head)
+ if 'original_head' in globals():
+ print("Restoring head pointer to %s" % original_head)
+ run_cmd("git checkout %s" % original_head)
- branches = run_cmd("git branch").replace(" ", "").split("\n")
+ branches = run_cmd("git branch").replace(" ", "").split("\n")
- for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches):
- print("Deleting local branch %s" % branch)
- run_cmd("git branch -D %s" % branch)
+ for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches):
+ print("Deleting local branch %s" % branch)
+ run_cmd("git branch -D %s" % branch)
# merge the requested PR and return the merge hash
@@ -510,7 +511,7 @@ def main():
import doctest
(failure_count, test_count) = doctest.testmod()
if failure_count:
- exit(-1)
+ sys.exit(-1)
try:
main()
except:
diff --git a/dev/run-pip-tests b/dev/run-pip-tests
index 1321c2be4c192..7271d1014e4ae 100755
--- a/dev/run-pip-tests
+++ b/dev/run-pip-tests
@@ -89,7 +89,7 @@ for python in "${PYTHON_EXECS[@]}"; do
source "$VIRTUALENV_PATH"/bin/activate
fi
# Upgrade pip & friends if using virutal env
- if [ ! -n "USE_CONDA" ]; then
+ if [ ! -n "$USE_CONDA" ]; then
pip install --upgrade pip pypandoc wheel numpy
fi
diff --git a/dev/run-tests.py b/dev/run-tests.py
index fe75ef4411c8c..cd4590864b7d7 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -204,7 +204,7 @@ def run_scala_style_checks():
def run_java_style_checks():
set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE")
- run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")])
+ run_cmd([os.path.join(SPARK_HOME, "dev", "sbt-checkstyle")])
def run_python_style_checks():
@@ -357,7 +357,7 @@ def build_spark_unidoc_sbt(hadoop_version):
exec_sbt(profiles_and_goals)
-def build_spark_assembly_sbt(hadoop_version):
+def build_spark_assembly_sbt(hadoop_version, checkstyle=False):
# Enable all of the profiles for the build:
build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags
sbt_goals = ["assembly/package"]
@@ -366,6 +366,9 @@ def build_spark_assembly_sbt(hadoop_version):
" ".join(profiles_and_goals))
exec_sbt(profiles_and_goals)
+ if checkstyle:
+ run_java_style_checks()
+
# Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build.
# Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the
# documentation build fails on a specific machine & environment in Jenkins but it was unable
@@ -570,12 +573,13 @@ def main():
or f.endswith("scalastyle-config.xml")
for f in changed_files):
run_scala_style_checks()
+ should_run_java_style_checks = False
if not changed_files or any(f.endswith(".java")
or f.endswith("checkstyle.xml")
or f.endswith("checkstyle-suppressions.xml")
for f in changed_files):
- # run_java_style_checks()
- pass
+ # Run SBT Checkstyle after the build to prevent a side-effect to the build.
+ should_run_java_style_checks = True
if not changed_files or any(f.endswith("lint-python")
or f.endswith("tox.ini")
or f.endswith(".py")
@@ -604,7 +608,7 @@ def main():
detect_binary_inop_with_mima(hadoop_version)
# Since we did not build assembly/package before running dev/mima, we need to
# do it here because the tests still rely on it; see SPARK-13294 for details.
- build_spark_assembly_sbt(hadoop_version)
+ build_spark_assembly_sbt(hadoop_version, should_run_java_style_checks)
# run the test suites
run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags)
@@ -621,7 +625,7 @@ def _test():
import doctest
failure_count = doctest.testmod()[0]
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
_test()
diff --git a/dev/sbt-checkstyle b/dev/sbt-checkstyle
new file mode 100755
index 0000000000000..8821a7c0e4ccf
--- /dev/null
+++ b/dev/sbt-checkstyle
@@ -0,0 +1,42 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file
+# with failure (either resolution or compilation); the "q" makes SBT quit.
+ERRORS=$(echo -e "q\n" \
+ | build/sbt \
+ -Pkinesis-asl \
+ -Pmesos \
+ -Pkafka-0-8 \
+ -Pkubernetes \
+ -Pyarn \
+ -Pflume \
+ -Phive \
+ -Phive-thriftserver \
+ checkstyle test:checkstyle \
+ | awk '{if($1~/error/)print}' \
+)
+
+if test ! -z "$ERRORS"; then
+ echo -e "Checkstyle failed at following occurrences:\n$ERRORS"
+ exit 1
+else
+ echo -e "Checkstyle checks passed."
+fi
+
diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh
index 3bf7618e1ea96..2fbd6b5e98f7f 100755
--- a/dev/test-dependencies.sh
+++ b/dev/test-dependencies.sh
@@ -34,6 +34,7 @@ MVN="build/mvn"
HADOOP_PROFILES=(
hadoop-2.6
hadoop-2.7
+ hadoop-3.1
)
# We'll switch the version to a temp. one, publish POMs using that new version, then switch back to
diff --git a/dev/tox.ini b/dev/tox.ini
index 583c1eaaa966b..28dad8f3b5c7c 100644
--- a/dev/tox.ini
+++ b/dev/tox.ini
@@ -16,4 +16,4 @@
[pycodestyle]
ignore=E402,E731,E241,W503,E226,E722,E741,E305
max-line-length=100
-exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/*
+exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/*,dist/*
diff --git a/docs/README.md b/docs/README.md
index 225bb1b2040de..dbea4d64c4298 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -5,7 +5,7 @@ here with the Spark source code. You can also find documentation specific to rel
Spark at http://spark.apache.org/documentation.html.
Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the
-documentation yourself. Why build it yourself? So that you have the docs that corresponds to
+documentation yourself. Why build it yourself? So that you have the docs that correspond to
whichever version of Spark you currently have checked out of revision control.
## Prerequisites
@@ -22,10 +22,13 @@ $ sudo gem install jekyll jekyll-redirect-from pygments.rb
$ sudo pip install Pygments
# Following is needed only for generating API docs
$ sudo pip install sphinx pypandoc mkdocs
-$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")'
+$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")'
+$ sudo Rscript -e 'devtools::install_version("roxygen2", version = "5.0.1", repos="http://cran.stat.ucla.edu/")'
```
-(Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0)
+Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0.
+
+Note: Other versions of roxygen2 might work for generating the SparkR documentation, but the `RoxygenNote` field in `$SPARK_HOME/R/pkg/DESCRIPTION` is 5.0.1 and should be updated if the installed version differs.
## Generating the Documentation HTML
@@ -62,12 +65,12 @@ $ PRODUCTION=1 jekyll build
## API Docs (Scaladoc, Javadoc, Sphinx, roxygen2, MkDocs)
-You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `SPARK_HOME` directory.
+You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `$SPARK_HOME` directory.
Similarly, you can build just the PySpark docs by running `make html` from the
-`SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as
-public in `__init__.py`. The SparkR docs can be built by running `SPARK_HOME/R/create-docs.sh`, and
-the SQL docs can be built by running `SPARK_HOME/sql/create-docs.sh`
+`$SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as
+public in `__init__.py`. The SparkR docs can be built by running `$SPARK_HOME/R/create-docs.sh`, and
+the SQL docs can be built by running `$SPARK_HOME/sql/create-docs.sh`
after [building Spark](https://github.com/apache/spark#building-spark) first.
When you run `jekyll build` in the `docs` directory, it will also copy over the scaladoc and javadoc for the various
diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb
index 6ea1d438f529e..1e91f12518e0b 100644
--- a/docs/_plugins/include_example.rb
+++ b/docs/_plugins/include_example.rb
@@ -48,7 +48,7 @@ def render(context)
begin
code = File.open(@file).read.encode("UTF-8")
rescue => e
- # We need to explicitly exit on execptions here because Jekyll will silently swallow
+ # We need to explicitly exit on exceptions here because Jekyll will silently swallow
# them, leading to silent build failures (see https://github.com/jekyll/jekyll/issues/5104)
puts(e)
puts(e.backtrace)
diff --git a/docs/building-spark.md b/docs/building-spark.md
index c391255a91596..0236bb05849ad 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -113,7 +113,7 @@ Note: Flume support is deprecated as of Spark 2.3.0.
## Building submodules individually
-It's possible to build Spark sub-modules using the `mvn -pl` option.
+It's possible to build Spark submodules using the `mvn -pl` option.
For instance, you can build the Spark Streaming module using:
diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md
index c150d9efc06ff..ac1c336988930 100644
--- a/docs/cloud-integration.md
+++ b/docs/cloud-integration.md
@@ -27,13 +27,13 @@ description: Introduction to cloud storage support in Apache Spark SPARK_VERSION
All major cloud providers offer persistent data storage in *object stores*.
These are not classic "POSIX" file systems.
In order to store hundreds of petabytes of data without any single points of failure,
-object stores replace the classic filesystem directory tree
+object stores replace the classic file system directory tree
with a simpler model of `object-name => data`. To enable remote access, operations
on objects are usually offered as (slow) HTTP REST operations.
Spark can read and write data in object stores through filesystem connectors implemented
in Hadoop or provided by the infrastructure suppliers themselves.
-These connectors make the object stores look *almost* like filesystems, with directories and files
+These connectors make the object stores look *almost* like file systems, with directories and files
and the classic operations on them such as list, delete and rename.
diff --git a/docs/configuration.md b/docs/configuration.md
index e7f2419cc2fa4..6aa7878fe614d 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -208,7 +208,7 @@ of the most common options to set are:
stored on disk. This should be on a fast, local disk in your system. It can also be a
comma-separated list of multiple directories on different disks.
- NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone, Mesos) or
+ NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone), MESOS_SANDBOX (Mesos) or
LOCAL_DIRS (YARN) environment variables set by the cluster manager.
@@ -328,6 +328,11 @@ Apart from these, the following properties are also available, and may be useful
Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this
option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file
used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory.
+
+ The following symbols, if present, will be interpolated: {{APP_ID}} will be replaced by
+ the application ID and {{EXECUTOR_ID}} will be replaced by the executor ID. For example, to enable
+ verbose gc logging to a file named for the executor ID of the app in /tmp, pass a 'value' of:
+ -verbose:gc -Xloggc:/tmp/{{APP_ID}}-{{EXECUTOR_ID}}.gc
@@ -451,6 +456,33 @@ Apart from these, the following properties are also available, and may be useful
from JVM to Python worker for every task.
+
+
spark.sql.repl.eagerEval.enabled
+
false
+
+ Whether to enable eager evaluation. If true, and if the REPL you are using supports eager evaluation,
+ the Dataset will be evaluated automatically. The HTML table generated by _repr_html_
+ (called by notebooks such as Jupyter) shows the result of the queries the user has defined. For a plain Python
+ REPL, the output is shown as with dataframe.show()
+ (see SPARK-24215 for more details).
+
+
+
+
spark.sql.repl.eagerEval.maxNumRows
+
20
+
+ The default number of rows shown in the eager evaluation output (the HTML table generated by _repr_html_ or plain text);
+ this only takes effect when spark.sql.repl.eagerEval.enabled is set to true.
+
+
+
+
spark.sql.repl.eagerEval.truncate
+
20
+
+ The default truncation length for values in the eager evaluation output (the HTML table generated by _repr_html_ or
+ plain text); this only takes effect when spark.sql.repl.eagerEval.enabled is set to true.
+
+
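[Editorial aside on the eager-evaluation settings added above: the sketch below shows how they might be used from a PySpark-backed notebook. It assumes a live SparkSession named `spark`; the DataFrame and values are illustrative and not part of this patch.]

```python
# Minimal sketch, assuming a SparkSession named `spark` already exists
# (e.g. in a Jupyter notebook backed by PySpark).
spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")
spark.conf.set("spark.sql.repl.eagerEval.maxNumRows", 10)   # rows rendered in the output
spark.conf.set("spark.sql.repl.eagerEval.truncate", 20)     # max width of a rendered cell

df = spark.range(5).toDF("id")
df   # in a notebook this now renders an HTML table via _repr_html_ instead of requiring df.show()
```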
spark.files
@@ -558,7 +590,7 @@ Apart from these, the following properties are also available, and may be useful
This configuration limits the number of remote requests to fetch blocks at any given point.
When the number of hosts in the cluster increase, it might lead to very large number
- of in-bound connections to one or more nodes, causing the workers to fail under load.
+ of inbound connections to one or more nodes, causing the workers to fail under load.
By allowing it to limit the number of fetch requests, this scenario can be mitigated.
@@ -712,30 +744,6 @@ Apart from these, the following properties are also available, and may be useful
When we fail to register to the external shuffle service, we will retry for maxAttempts times.
-
-
spark.io.encryption.enabled
-
false
-
- Enable IO encryption. Currently supported by all modes except Mesos. It's recommended that RPC encryption
- be enabled when using this feature.
-
-
-
-
spark.io.encryption.keySizeBits
-
128
-
- IO encryption key size in bits. Supported values are 128, 192 and 256.
-
-
-
-
spark.io.encryption.keygen.algorithm
-
HmacSHA1
-
- The algorithm to use when generating the IO encryption key. The supported algorithms are
- described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm
- Name Documentation.
-
-
### Spark UI
@@ -893,6 +901,23 @@ Apart from these, the following properties are also available, and may be useful
How many dead executors the Spark UI and status APIs remember before garbage collecting.
+
+
spark.ui.filters
+
None
+
+ Comma separated list of filter class names to apply to the Spark Web UI. The filter should be a
+ standard
+ javax servlet Filter.
+
+ Filter parameters can also be specified in the configuration, by setting config entries
+ of the form spark.<class name of filter>.param.<param name>=<value>
+
+ For example:
+ spark.ui.filters=com.test.filter1
+ spark.com.test.filter1.param.name1=foo
+ spark.com.test.filter1.param.name2=bar
+
+
### Compression and Serialization
@@ -912,8 +937,8 @@ Apart from these, the following properties are also available, and may be useful
lz4
The codec used to compress internal data such as RDD partitions, event log, broadcast variables
- and shuffle outputs. By default, Spark provides three codecs: lz4, lzf,
- and snappy. You can also use fully qualified class names to specify the codec,
+ and shuffle outputs. By default, Spark provides four codecs: lz4, lzf,
+ snappy, and zstd. You can also use fully qualified class names to specify the codec,
e.g.
org.apache.spark.io.LZ4CompressionCodec,
org.apache.spark.io.LZFCompressionCodec,
@@ -1295,7 +1320,7 @@ Apart from these, the following properties are also available, and may be useful
4194304 (4 MB)
The estimated cost to open a file, measured by the number of bytes could be scanned at the same
- time. This is used when putting multiple files into a partition. It is better to over estimate,
+ time. This is used when putting multiple files into a partition. It is better to overestimate,
then the partitions with small files will be faster than partitions with bigger files.
@@ -1446,6 +1471,15 @@ Apart from these, the following properties are also available, and may be useful
Duration for an RPC remote endpoint lookup operation to wait before timing out.
+
+
spark.core.connection.ack.wait.timeout
+
spark.network.timeout
+
+ How long the connection should wait for an ack to occur before timing
+ out and giving up. To avoid undesired timeouts caused by long pauses such as GC,
+ you can set a larger value.
+
+
### Scheduling
@@ -1511,7 +1545,7 @@ Apart from these, the following properties are also available, and may be useful
0.8 for KUBERNETES mode; 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode
The minimum ratio of registered resources (registered resources / total expected resources)
- (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarsed-grained
+ (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarse-grained
mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] )
to wait for before scheduling begins. Specified as a double between 0.0 and 1.0.
Regardless of whether the minimum ratio of resources has been reached,
@@ -1622,9 +1656,10 @@ Apart from these, the following properties are also available, and may be useful
spark.blacklist.killBlacklistedExecutors
false
- (Experimental) If set to "true", allow Spark to automatically kill, and attempt to re-create,
- executors when they are blacklisted. Note that, when an entire node is added to the blacklist,
- all of the executors on that node will be killed.
+ (Experimental) If set to "true", allow Spark to automatically kill the executors
+ when they are blacklisted on fetch failure or blacklisted for the entire application,
+ as controlled by spark.blacklist.application.*. Note that, when an entire node is added
+ to the blacklist, all of the executors on that node will be killed.
@@ -1632,7 +1667,7 @@ Apart from these, the following properties are also available, and may be useful
false
(Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch
- failure happenes. If external shuffle service is enabled, then the whole node will be
+ failure happens. If external shuffle service is enabled, then the whole node will be
blacklisted.
@@ -1720,7 +1755,7 @@ Apart from these, the following properties are also available, and may be useful
When spark.task.reaper.enabled = true, this setting specifies a timeout after
which the executor JVM will kill itself if a killed task has not stopped running. The default
value, -1, disables this mechanism and prevents the executor from self-destructing. The purpose
- of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering
+ of this setting is to act as a safety-net to prevent runaway noncancellable tasks from rendering
an executor unusable.
@@ -1751,6 +1786,7 @@ Apart from these, the following properties are also available, and may be useful
spark.dynamicAllocation.minExecutors,
spark.dynamicAllocation.maxExecutors, and
spark.dynamicAllocation.initialExecutors
+ spark.dynamicAllocation.executorAllocationRatio
@@ -1795,6 +1831,23 @@ Apart from these, the following properties are also available, and may be useful
Lower bound for the number of executors if dynamic allocation is enabled.
+
+
spark.dynamicAllocation.executorAllocationRatio
+
1
+
+ By default, dynamic allocation requests enough executors to maximize
+ parallelism according to the number of tasks to process. While this minimizes the
+ latency of the job, with small tasks this setting can waste a lot of resources due to
+ executor allocation overhead, as some executors might not even do any work.
+ This setting lets you provide a ratio that will be used to reduce the number of
+ executors with respect to full parallelism.
+ It defaults to 1.0 to give maximum parallelism;
+ 0.5 will divide the target number of executors by 2.
+ The target number of executors computed by dynamic allocation can still be overridden
+ by the spark.dynamicAllocation.minExecutors and
+ spark.dynamicAllocation.maxExecutors settings.
+
+
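[Editorial aside: to make the ratio concrete, here is a rough, hypothetical back-of-the-envelope calculation of how `spark.dynamicAllocation.executorAllocationRatio` scales the executor target. The authoritative formula lives in Spark's ExecutorAllocationManager and may differ in detail; all variable names below are made up for illustration.]

```python
import math

# Illustrative arithmetic only -- not the exact Spark implementation.
pending_and_running_tasks = 1000
tasks_per_executor = 4            # roughly spark.executor.cores / spark.task.cpus
ratio = 0.5                       # spark.dynamicAllocation.executorAllocationRatio

full_parallelism_target = math.ceil(pending_and_running_tasks / tasks_per_executor)  # 250
scaled_target = math.ceil(full_parallelism_target * ratio)                           # 125

# minExecutors / maxExecutors still clamp the final value.
min_executors, max_executors = 2, 100
target = min(max(scaled_target, min_executors), max_executors)                       # 100
print(target)
```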
spark.dynamicAllocation.schedulerBacklogTimeout
1s
@@ -1817,313 +1870,8 @@ Apart from these, the following properties are also available, and may be useful
### Security
-
-
Property Name
Default
Meaning
-
-
spark.acls.enable
-
false
-
- Whether Spark acls should be enabled. If enabled, this checks to see if the user has
- access permissions to view or modify the job. Note this requires the user to be known,
- so if the user comes across as null no checks are done. Filters can be used with the UI
- to authenticate and set the user.
-
-
-
-
spark.admin.acls
-
Empty
-
- Comma separated list of users/administrators that have view and modify access to all Spark jobs.
- This can be used if you run on a shared cluster and have a set of administrators or devs who
- help debug when things do not work. Putting a "*" in the list means any user can have the
- privilege of admin.
-
-
-
-
spark.admin.acls.groups
-
Empty
-
- Comma separated list of groups that have view and modify access to all Spark jobs.
- This can be used if you have a set of administrators or developers who help maintain and debug
- the underlying infrastructure. Putting a "*" in the list means any user in any group can have
- the privilege of admin. The user groups are obtained from the instance of the groups mapping
- provider specified by spark.user.groups.mapping. Check the entry
- spark.user.groups.mapping for more details.
-
- The list of groups for a user is determined by a group mapping service defined by the trait
- org.apache.spark.security.GroupMappingServiceProvider which can be configured by this property.
- A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider
- which can be specified to resolve a list of groups for a user.
- Note: This implementation supports only a Unix/Linux based environment. Windows environment is
- currently not supported. However, a new platform/protocol can be supported by implementing
- the trait org.apache.spark.security.GroupMappingServiceProvider.
-
-
-
-
spark.authenticate
-
false
-
- Whether Spark authenticates its internal connections. See
- spark.authenticate.secret if not running on YARN.
-
-
-
-
spark.authenticate.secret
-
None
-
- Set the secret key used for Spark to authenticate between components. This needs to be set if
- not running on YARN and authentication is enabled.
-
-
-
-
spark.network.crypto.enabled
-
false
-
- Enable encryption using the commons-crypto library for RPC and block transfer service.
- Requires spark.authenticate to be enabled.
-
-
-
-
spark.network.crypto.keyLength
-
128
-
- The length in bits of the encryption key to generate. Valid values are 128, 192 and 256.
-
-
-
-
spark.network.crypto.keyFactoryAlgorithm
-
PBKDF2WithHmacSHA1
-
- The key factory algorithm to use when generating encryption keys. Should be one of the
- algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used.
-
-
-
-
spark.network.crypto.saslFallback
-
true
-
- Whether to fall back to SASL authentication if authentication fails using Spark's internal
- mechanism. This is useful when the application is connecting to old shuffle services that
- do not support the internal Spark authentication protocol. On the server side, this can be
- used to block older clients from authenticating against a new shuffle service.
-
-
-
-
spark.network.crypto.config.*
-
None
-
- Configuration values for the commons-crypto library, such as which cipher implementations to
- use. The config name should be the name of commons-crypto configuration without the
- "commons.crypto" prefix.
-
-
-
-
spark.authenticate.enableSaslEncryption
-
false
-
- Enable encrypted communication when authentication is
- enabled. This is supported by the block transfer service and the
- RPC endpoints.
-
-
-
-
spark.network.sasl.serverAlwaysEncrypt
-
false
-
- Disable unencrypted connections for services that support SASL authentication.
-
-
-
-
spark.core.connection.ack.wait.timeout
-
spark.network.timeout
-
- How long for the connection to wait for ack to occur before timing
- out and giving up. To avoid unwilling timeout caused by long pause like GC,
- you can set larger value.
-
-
-
-
spark.modify.acls
-
Empty
-
- Comma separated list of users that have modify access to the Spark job. By default only the
- user that started the Spark job has access to modify it (kill it for example). Putting a "*" in
- the list means any user can have access to modify it.
-
-
-
-
spark.modify.acls.groups
-
Empty
-
- Comma separated list of groups that have modify access to the Spark job. This can be used if you
- have a set of administrators or developers from the same team to have access to control the job.
- Putting a "*" in the list means any user in any group has the access to modify the Spark job.
- The user groups are obtained from the instance of the groups mapping provider specified by
- spark.user.groups.mapping. Check the entry spark.user.groups.mapping
- for more details.
-
-
-
-
spark.ui.filters
-
None
-
- Comma separated list of filter class names to apply to the Spark web UI. The filter should be a
- standard
- javax servlet Filter. Parameters to each filter can also be specified by setting a
- java system property of:
- spark.<class name of filter>.params='param1=value1,param2=value2'
- For example:
- -Dspark.ui.filters=com.test.filter1
- -Dspark.com.test.filter1.params='param1=foo,param2=testing'
-
-
-
-
spark.ui.view.acls
-
Empty
-
- Comma separated list of users that have view access to the Spark web ui. By default only the
- user that started the Spark job has view access. Putting a "*" in the list means any user can
- have view access to this Spark job.
-
-
-
-
spark.ui.view.acls.groups
-
Empty
-
- Comma separated list of groups that have view access to the Spark web ui to view the Spark Job
- details. This can be used if you have a set of administrators or developers or users who can
- monitor the Spark job submitted. Putting a "*" in the list means any user in any group can view
- the Spark job details on the Spark web ui. The user groups are obtained from the instance of the
- groups mapping provider specified by spark.user.groups.mapping. Check the entry
- spark.user.groups.mapping for more details.
-
-
-
-
-### TLS / SSL
-
-
-
Property Name
Default
Meaning
-
-
spark.ssl.enabled
-
false
-
- Whether to enable SSL connections on all supported protocols.
-
- When spark.ssl.enabled is configured, spark.ssl.protocol
- is required.
-
- All the SSL settings like spark.ssl.xxx where xxx is a
- particular configuration property, denote the global configuration for all the supported
- protocols. In order to override the global configuration for the particular protocol,
- the properties must be overwritten in the protocol-specific namespace.
-
- Use spark.ssl.YYY.XXX settings to overwrite the global configuration for
- particular protocol denoted by YYY. Example values for YYY
- include fs, ui, standalone, and
- historyServer. See SSL
- Configuration for details on hierarchical SSL configuration for services.
-
-
-
-
spark.ssl.[namespace].port
-
None
-
- The port where the SSL service will listen on.
-
- The port must be defined within a namespace configuration; see
- SSL Configuration for the available
- namespaces.
-
- When not set, the SSL port will be derived from the non-SSL port for the
- same service. A value of "0" will make the service bind to an ephemeral port.
-
-
-
-
spark.ssl.enabledAlgorithms
-
Empty
-
- A comma separated list of ciphers. The specified ciphers must be supported by JVM.
- The reference list of protocols one can find on
- this
- page.
- Note: If not set, it will use the default cipher suites of JVM.
-
-
-
-
spark.ssl.keyPassword
-
None
-
- A password to the private key in key-store.
-
-
-
-
spark.ssl.keyStore
-
None
-
- A path to a key-store file. The path can be absolute or relative to the directory where
- the component is started in.
-
-
-
-
spark.ssl.keyStorePassword
-
None
-
- A password to the key-store.
-
-
-
-
spark.ssl.keyStoreType
-
JKS
-
- The type of the key-store.
-
-
-
-
spark.ssl.protocol
-
None
-
- A protocol name. The protocol must be supported by JVM. The reference list of protocols
- one can find on this
- page.
-
-
-
-
spark.ssl.needClientAuth
-
false
-
- Set true if SSL needs client authentication.
-
-
-
-
spark.ssl.trustStore
-
None
-
- A path to a trust-store file. The path can be absolute or relative to the directory
- where the component is started in.
-
-
-
-
spark.ssl.trustStorePassword
-
None
-
- A password to the trust-store.
-
-
-
-
spark.ssl.trustStoreType
-
JKS
-
- The type of the trust-store.
-
-
-
-
+Please refer to the [Security](security.html) page for available options on how to secure different
+Spark subsystems.
### Spark SQL
@@ -2218,8 +1966,8 @@ showDF(properties, numRows = 200, truncate = FALSE)
spark.streaming.receiver.writeAheadLog.enable
false
- Enable write ahead logs for receivers. All the input data received through receivers
- will be saved to write ahead logs that will allow it to be recovered after driver failures.
+ Enable write-ahead logs for receivers. All the input data received through receivers
+ will be saved to write-ahead logs that will allow it to be recovered after driver failures.
See the deployment guide
in the Spark Streaming programing guide for more details.
- Whether to close the file after writing a write ahead log record on the driver. Set this to 'true'
+ Whether to close the file after writing a write-ahead log record on the driver. Set this to 'true'
when you want to use S3 (or any file system that does not support flushing) for the metadata WAL
on the driver.
- Whether to close the file after writing a write ahead log record on the receivers. Set this to 'true'
+ Whether to close the file after writing a write-ahead log record on the receivers. Set this to 'true'
when you want to use S3 (or any file system that does not support flushing) for the data WAL
on the receivers.
@@ -2481,7 +2229,7 @@ Spark's classpath for each application. In a Spark cluster running on YARN, thes
files are set cluster-wide, and cannot safely be changed by the application.
The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`.
-They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-defalut.conf`
+They can be considered the same as normal Spark properties which can be set in `$SPARK_HOME/conf/spark-defaults.conf`
In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For
instance, Spark allows you to simply create an empty conf and set spark/spark hadoop properties.
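[Editorial aside: a small hedged illustration of the `spark.hadoop.*` prefix described above. The particular Hadoop property used here is only an example of the pattern, not a recommendation from this patch.]

```python
from pyspark import SparkConf
from pyspark.sql import SparkSession

# Any key prefixed with "spark.hadoop." is passed through to the Hadoop configuration.
conf = SparkConf().set("spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version", "2")
spark = SparkSession.builder.config(conf=conf).getOrCreate()
```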
diff --git a/docs/css/pygments-default.css b/docs/css/pygments-default.css
index 6247cd8396cf1..a4d583b366603 100644
--- a/docs/css/pygments-default.css
+++ b/docs/css/pygments-default.css
@@ -5,7 +5,7 @@ To generate this, I had to run
But first I had to install pygments via easy_install pygments
I had to override the conflicting bootstrap style rules by linking to
-this stylesheet lower in the html than the bootstap css.
+this stylesheet lower in the html than the bootstrap css.
Also, I was thrown off for a while at first when I was using markdown
code block inside my {% highlight scala %} ... {% endhighlight %} tags
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
index 5c97a248df4bc..35293348e3f3d 100644
--- a/docs/graphx-programming-guide.md
+++ b/docs/graphx-programming-guide.md
@@ -491,7 +491,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts)(
The more general [`outerJoinVertices`][Graph.outerJoinVertices] behaves similarly to `joinVertices`
except that the user defined `map` function is applied to all vertices and can change the vertex
property type. Because not all vertices may have a matching value in the input RDD the `map`
-function takes an `Option` type. For example, we can setup a graph for PageRank by initializing
+function takes an `Option` type. For example, we can set up a graph for PageRank by initializing
vertex properties with their `outDegree`.
@@ -969,7 +969,7 @@ A vertex is part of a triangle when it has two adjacent vertices with an edge be
# Examples
Suppose I want to build a graph from some text files, restrict the graph
-to important relationships and users, run page-rank on the sub-graph, and
+to important relationships and users, run page-rank on the subgraph, and
then finally return attributes associated with the top users. I can do
all of this in just a few lines with GraphX:
diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md
index e6d881639a13b..da90342406c84 100644
--- a/docs/job-scheduling.md
+++ b/docs/job-scheduling.md
@@ -23,7 +23,7 @@ run tasks and store data for that application. If multiple users need to share y
different options to manage allocation, depending on the cluster manager.
The simplest option, available on all cluster managers, is _static partitioning_ of resources. With
-this approach, each application is given a maximum amount of resources it can use, and holds onto them
+this approach, each application is given a maximum amount of resources it can use and holds onto them
for its whole duration. This is the approach used in Spark's [standalone](spark-standalone.html)
and [YARN](running-on-yarn.html) modes, as well as the
[coarse-grained Mesos mode](running-on-mesos.html#mesos-run-modes).
@@ -230,7 +230,7 @@ properties:
* `minShare`: Apart from an overall weight, each pool can be given a _minimum shares_ (as a number of
CPU cores) that the administrator would like it to have. The fair scheduler always attempts to meet
all active pools' minimum shares before redistributing extra resources according to the weights.
- The `minShare` property can therefore be another way to ensure that a pool can always get up to a
+ The `minShare` property can, therefore, be another way to ensure that a pool can always get up to a
certain number of resources (e.g. 10 cores) quickly without giving it a high priority for the rest
of the cluster. By default, each pool's `minShare` is 0.
diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md
index 2747f2df7cb10..375957e92cc4c 100644
--- a/docs/ml-advanced.md
+++ b/docs/ml-advanced.md
@@ -77,7 +77,7 @@ Quasi-Newton methods in this case. This fallback is currently always enabled for
L1 regularization is applied (i.e. $\alpha = 0$), there exists an analytical solution and either Cholesky or Quasi-Newton solver may be used. When $\alpha > 0$ no analytical
solution exists and we instead use the Quasi-Newton solver to find the coefficients iteratively.
-In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features be no more than 4096. For larger problems, use L-BFGS instead.
+In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features is no more than 4096. For larger problems, use L-BFGS instead.
## Iteratively reweighted least squares (IRLS)
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index ddd2f4b49ca07..b3d109039da4d 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -420,7 +420,7 @@ Refer to the [R API docs](api/R/spark.svmLinear.html) for more details.
[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All."
-`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes.
+`OneVsRest` is implemented as an `Estimator`. For the base classifier, it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes.
Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label.
@@ -455,11 +455,29 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat
## Naive Bayes
[Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple
-probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence
-assumptions between the features. The `spark.ml` implementation currently supports both [multinomial
-naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html)
+probabilistic, multiclass classifiers based on applying Bayes' theorem with strong (naive) independence
+assumptions between every pair of features.
+
+Naive Bayes can be trained very efficiently. With a single pass over the training data,
+it computes the conditional probability distribution of each feature given each label.
+For prediction, it applies Bayes' theorem to compute the conditional probability distribution
+of each label given an observation.
+
+MLlib supports both [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes)
and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
-More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib).
+
+*Input data*:
+These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
+Within that context, each observation is a document and each feature represents a term.
+A feature's value is the frequency of the term (in multinomial Naive Bayes) or
+a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes).
+Feature values must be *non-negative*. The model type is selected with an optional parameter
+"multinomial" or "bernoulli" with "multinomial" as the default.
+For document classification, the input feature vectors should usually be sparse vectors.
+Since the training data is only used once, it is not necessary to cache it.
+
+[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
+setting the parameter $\lambda$ (which defaults to $1.0$).
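A short sketch of how the model type and smoothing parameter described above map onto the `spark.ml` API, assuming `docs` is a DataFrame of labelled term-frequency vectors:

{% highlight scala %}
import org.apache.spark.ml.classification.NaiveBayes

// "multinomial" (the default) expects term counts; "bernoulli" expects 0/1 features.
val nb = new NaiveBayes()
  .setModelType("multinomial")
  .setSmoothing(1.0)   // additive (Lidstone) smoothing, i.e. the lambda parameter

// `docs` is an assumed DataFrame with "label" and sparse "features" columns.
val nbModel = nb.fit(docs)
{% endhighlight %}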
**Examples**
@@ -908,7 +926,7 @@ Refer to the [R API docs](api/R/spark.survreg.html) for more details.
belongs to the family of regression algorithms. Formally isotonic regression is a problem where
given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses
and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted
-finding a function that minimises
+finding a function that minimizes
`\begin{equation}
f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2
@@ -927,7 +945,7 @@ We implement a
which uses an approach to
[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10).
The training input is a DataFrame which contains three columns
-label, features and weight. Additionally IsotonicRegression algorithm has one
+label, features and weight. Additionally, the IsotonicRegression algorithm has one
optional parameter called $isotonic$ defaulting to true.
This argument specifies if the isotonic regression is
isotonic (monotonically increasing) or antitonic (monotonically decreasing).
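A minimal sketch, assuming `df` carries the `label`, `features` and `weight` columns described above:

{% highlight scala %}
import org.apache.spark.ml.regression.IsotonicRegression

// `df` is an assumed DataFrame with "label", "features" and "weight" columns.
val ir = new IsotonicRegression()
  .setIsotonic(true)        // false fits an antitonic (decreasing) function instead
  .setWeightCol("weight")

val irModel = ir.fit(df)
irModel.transform(df).show()
{% endhighlight %}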
diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md
index 58f2d4b531e70..8b0f287dc39ad 100644
--- a/docs/ml-collaborative-filtering.md
+++ b/docs/ml-collaborative-filtering.md
@@ -35,7 +35,7 @@ but the ids must be within the integer value range.
### Explicit vs. implicit feedback
-The standard approach to matrix factorization based collaborative filtering treats
+The standard approach to matrix factorization-based collaborative filtering treats
the entries in the user-item matrix as *explicit* preferences given by the user to the item,
for example, users giving ratings to movies.
diff --git a/docs/ml-features.md b/docs/ml-features.md
index 3370eb3893272..7aed2341584fc 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -1174,7 +1174,7 @@ for more details on the API.
## SQLTransformer
`SQLTransformer` implements the transformations which are defined by a SQL statement.
-Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."`
+Currently, we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."`
where `"__THIS__"` represents the underlying table of the input dataset.
The select clause specifies the fields, constants, and expressions to display in
the output, and can be any select clause that Spark SQL supports. Users can also
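For instance, a transformer that derives two new columns from the incoming dataset might look like the following sketch:

{% highlight scala %}
import org.apache.spark.ml.feature.SQLTransformer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("sql-transformer").getOrCreate()
val df = spark.createDataFrame(Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")

// __THIS__ is replaced by the DataFrame passed to transform().
val sqlTrans = new SQLTransformer()
  .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")

sqlTrans.transform(df).show()
{% endhighlight %}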
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 702bcf748fc74..aea07be34cb86 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -111,7 +111,7 @@ and the migration guide below will explain all changes between releases.
* The class and trait hierarchy for logistic regression model summaries was changed to be cleaner
and better accommodate the addition of the multi-class summary. This is a breaking change for user
code that casts a `LogisticRegressionTrainingSummary` to a
-` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary`
+`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary`
method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail
(_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which
will still work correctly for both multinomial and binary cases.
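As a sketch of the recommended usage, assuming `training` is a binary-labelled DataFrame with a `features` column:

{% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegression

// `training` is an assumed binary-labelled DataFrame.
val model = new LogisticRegression().fit(training)

// Instead of casting model.summary, use the dedicated accessor:
val binarySummary = model.binarySummary
println(s"Area under ROC: ${binarySummary.areaUnderROC}")
{% endhighlight %}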
diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md
index f4b0df58cf63b..e4736411fb5fe 100644
--- a/docs/ml-migration-guides.md
+++ b/docs/ml-migration-guides.md
@@ -347,7 +347,7 @@ rather than using the old parameter class `Strategy`. These new training method
separate classification and regression, and they replace specialized parameter types with
simple `String` types.
-Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the
+Examples of the new recommended `trainClassifier` and `trainRegressor` are given in the
[Decision Trees Guide](mllib-decision-tree.html#examples).
## From 0.9 to 1.0
diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md
index aa92c0a37c0f4..e22e9003c30f6 100644
--- a/docs/ml-pipeline.md
+++ b/docs/ml-pipeline.md
@@ -188,9 +188,36 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s.
For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`.
This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`.
-## Saving and Loading Pipelines
+## ML persistence: Saving and Loading Pipelines
-Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported.
+Often it is worthwhile to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API.
+As of Spark 2.3, the DataFrame-based API in `spark.ml` and `pyspark.ml` has complete coverage.
+
+ML persistence works across Scala, Java and Python. However, R currently uses a modified format,
+so models saved in R can only be loaded back in R; this should be fixed in the future and is
+tracked in [SPARK-15572](https://issues.apache.org/jira/browse/SPARK-15572).
+
+### Backwards compatibility for ML persistence
+
+In general, MLlib maintains backwards compatibility for ML persistence. That is, if you save an ML
+model or Pipeline in one version of Spark, then you should be able to load it back and use it in a
+future version of Spark. However, there are rare exceptions, described below.
+
+Model persistence: Is a model or Pipeline saved using Apache Spark ML persistence in Spark
+version X loadable by Spark version Y?
+
+* Major versions: No guarantees, but best-effort.
+* Minor and patch versions: Yes; these are backwards compatible.
+* Note about the format: There are no guarantees for a stable persistence format, but model loading itself is designed to be backwards compatible.
+
+Model behavior: Does a model or Pipeline in Spark version X behave identically in Spark version Y?
+
+* Major versions: No guarantees, but best-effort.
+* Minor and patch versions: Identical behavior, except for bug fixes.
+
+For both model persistence and model behavior, any breaking changes across a minor version or patch
+version are reported in the Spark version release notes. If a breakage is not reported in release
+notes, then it should be treated as a bug to be fixed.
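A minimal save/load sketch; the path is illustrative and `pipelineModel` is assumed to be a fitted `PipelineModel`:

{% highlight scala %}
import org.apache.spark.ml.PipelineModel

// Persist the fitted pipeline (overwriting any previous copy at that path).
pipelineModel.write.overwrite().save("/tmp/fitted-pipeline-model")

// Load it back, possibly in a later (minor or patch) Spark version.
val sameModel = PipelineModel.load("/tmp/fitted-pipeline-model")
{% endhighlight %}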
# Code examples
diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md
index 54d9cd21909df..028bfec465bab 100644
--- a/docs/ml-tuning.md
+++ b/docs/ml-tuning.md
@@ -103,7 +103,7 @@ Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.m
In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning.
`TrainValidationSplit` only evaluates each combination of parameters once, as opposed to k times in
- the case of `CrossValidator`. It is therefore less expensive,
+ the case of `CrossValidator`. It is, therefore, less expensive,
but will not produce as reliable results when the training dataset is not sufficiently large.
Unlike `CrossValidator`, `TrainValidationSplit` creates a single (training, test) dataset pair.
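A rough sketch of the split-based tuning flow, assuming `data` is a DataFrame with `label` and `features` columns:

{% highlight scala %}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LinearRegression()
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept)
  .build()

// 80% of the data is used for training and 20% for validation; each parameter
// combination is evaluated once rather than k times.
val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)

// `data` is an assumed DataFrame with "label" and "features" columns.
val tvsModel = tvs.fit(data)
{% endhighlight %}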
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index df2be92d860e4..dc6b095f5d59b 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -42,7 +42,7 @@ The following code snippets can be executed in `spark-shell`.
In the following example after loading and parsing data, we use the
[`KMeans`](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) object to cluster the data
into two clusters. The number of desired clusters is passed to the algorithm. We then compute Within
-Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact the
+Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact, the
optimal *k* is usually one where there is an "elbow" in the WSSSE graph.
Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`KMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeansModel) for details on the API.
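A condensed sketch of the flow described above (the input path is illustrative):

{% highlight scala %}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// `sc` is an existing SparkContext; the file contains whitespace-separated doubles.
val data = sc.textFile("data/mllib/kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()

// Cluster into k = 2 groups and report the Within Set Sum of Squared Errors.
val clusters = KMeans.train(parsedData, 2, 20)
val WSSSE = clusters.computeCost(parsedData)
println(s"Within Set Sum of Squared Errors = $WSSSE")
{% endhighlight %}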
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index 76a00f18b3b90..b2300028e151b 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -31,7 +31,7 @@ following parameters:
### Explicit vs. implicit feedback
-The standard approach to matrix factorization based collaborative filtering treats
+The standard approach to matrix factorization-based collaborative filtering treats
the entries in the user-item matrix as *explicit* preferences given by the user to the item,
for example, users giving ratings to movies.
@@ -60,7 +60,7 @@ best parameter learned from a sampled subset to the full dataset and expect simi
-In the following example we load rating data. Each row consists of a user, a product and a rating.
+In the following example, we load rating data. Each row consists of a user, a product and a rating.
We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS$)
method which assumes ratings are explicit. We evaluate the
recommendation model by measuring the Mean Squared Error of rating prediction.
diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md
index 35cee3275e3b5..5066bb29387dc 100644
--- a/docs/mllib-data-types.md
+++ b/docs/mllib-data-types.md
@@ -350,7 +350,7 @@ which is a tuple of `(Int, Int, Matrix)`.
***Note***
The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size.
-In general the use of non-deterministic RDDs can lead to errors.
+In general, the use of non-deterministic RDDs can lead to errors.
### RowMatrix
diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md
index a72680d52a26c..4e6b4530942f1 100644
--- a/docs/mllib-dimensionality-reduction.md
+++ b/docs/mllib-dimensionality-reduction.md
@@ -91,7 +91,7 @@ The same code applies to `IndexedRowMatrix` if `U` is defined as an
[Principal component analysis (PCA)](http://en.wikipedia.org/wiki/Principal_component_analysis) is a
statistical method to find a rotation such that the first coordinate has the largest variance
-possible, and each succeeding coordinate in turn has the largest variance possible. The columns of
+possible, and each succeeding coordinate, in turn, has the largest variance possible. The columns of
the rotation matrix are called principal components. PCA is used widely in dimensionality reduction.
`spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors.
diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
index 7f277543d2e9a..d9dbbab4840a3 100644
--- a/docs/mllib-evaluation-metrics.md
+++ b/docs/mllib-evaluation-metrics.md
@@ -13,7 +13,7 @@ of the model on some criteria, which depends on the application and its requirem
suite of metrics for the purpose of evaluating the performance of machine learning models.
Specific machine learning algorithms fall under broader types of machine learning applications like classification,
-regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those
+regression, clustering, etc. Each of these types has well-established metrics for performance evaluation, and those
metrics that are currently available in `spark.mllib` are detailed in this section.
## Classification model evaluation
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 75aea70601875..bb29f65c0322f 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -105,7 +105,7 @@ p(w_i | w_j ) = \frac{\exp(u_{w_i}^{\top}v_{w_j})}{\sum_{l=1}^{V} \exp(u_l^{\top
\]`
where $V$ is the vocabulary size.
-The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$
+The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$
is proportional to $V$, which can easily be on the order of millions. To speed up training of Word2Vec,
we used hierarchical softmax, which reduced the complexity of computing $\log p(w_i | w_j)$ to
$O(\log(V))$.
@@ -278,8 +278,8 @@ for details on the API.
multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This
represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29)
between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector.
-Qu8T948*1#
-Denoting the `scalingVec` as "`w`," this transformation may be written as:
+
+Denoting the `scalingVec` as "`w`", this transformation may be written as:
`\[ \begin{pmatrix}
v_1 \\
diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md
index ca84551506b2b..99cab98c690c6 100644
--- a/docs/mllib-isotonic-regression.md
+++ b/docs/mllib-isotonic-regression.md
@@ -9,7 +9,7 @@ displayTitle: Regression - RDD-based API
belongs to the family of regression algorithms. Formally isotonic regression is a problem where
given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses
and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted
-finding a function that minimises
+finding a function that minimizes
`\begin{equation}
f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2
@@ -28,7 +28,7 @@ best fitting the original data points.
which uses an approach to
[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10).
The training input is an RDD of tuples of three double values that represent
-label, feature and weight in this order. Additionally IsotonicRegression algorithm has one
+label, feature and weight in this order. Additionally, the IsotonicRegression algorithm has one
optional parameter called $isotonic$ defaulting to true.
This argument specifies if the isotonic regression is
isotonic (monotonically increasing) or antitonic (monotonically decreasing).
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index 034e89e25000e..73f6e206ca543 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -425,7 +425,7 @@ We create our model by initializing the weights to zero and register the streams
testing then start the job. Printing predictions alongside true labels lets us easily see the
result.
-Finally we can save text files with data to the training or testing folders.
+Finally, we can save text files with data to the training or testing folders.
Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label
and `x1,x2,x3` are the features. Anytime a text file is placed in `args(0)`
the model will update. Anytime a text file is placed in `args(1)` you will see predictions.
diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md
index 14d76a6e41e23..04758903da89c 100644
--- a/docs/mllib-optimization.md
+++ b/docs/mllib-optimization.md
@@ -121,7 +121,7 @@ computation of the sum of the partial results from each worker machine is perfor
standard spark routines.
If the fraction of points `miniBatchFraction` is set to 1 (default), then the resulting step in
-each iteration is exact (sub)gradient descent. In this case there is no randomness and no
+each iteration is exact (sub)gradient descent. In this case, there is no randomness and no
variance in the used step directions.
On the other extreme, if `miniBatchFraction` is chosen very small, such that only a single point
is sampled, i.e. `$|S|=$ miniBatchFraction $\cdot n = 1$`, then the algorithm is equivalent to
@@ -135,7 +135,7 @@ algorithm in the family of quasi-Newton methods to solve the optimization proble
quadratic without evaluating the second partial derivatives of the objective function to construct the
Hessian matrix. The Hessian matrix is approximated by previous gradient evaluations, so there is no
vertical scalability issue (the number of training features) when computing the Hessian matrix
-explicitly in Newton's method. As a result, L-BFGS often achieves rapider convergence compared with
+explicitly in Newton's method. As a result, L-BFGS often achieves more rapid convergence compared with
other first-order optimization.
### Choosing an Optimization Method
diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md
index d3530908706d0..f567565437927 100644
--- a/docs/mllib-pmml-model-export.md
+++ b/docs/mllib-pmml-model-export.md
@@ -7,7 +7,7 @@ displayTitle: PMML model export - RDD-based API
* Table of contents
{:toc}
-## `spark.mllib` supported models
+## spark.mllib supported models
`spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)).
@@ -15,7 +15,7 @@ The table below outlines the `spark.mllib` models that can be exported to PMML a
-
`spark.mllib` model
PMML model
+
spark.mllib model
PMML model
diff --git a/docs/monitoring.md b/docs/monitoring.md
index 6f6cfc1288d73..6eaf33135744d 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -80,7 +80,10 @@ The history server can be configured as follows:
-### Spark configuration options
+### Spark History Server Configuration Options
+
+Security options for the Spark History Server are covered in more detail in the
+[Security](security.html#web-ui) page.
Property Name
Default
Meaning
@@ -160,41 +163,6 @@ The history server can be configured as follows:
Location of the kerberos keytab file for the History Server.
-
-
spark.history.ui.acls.enable
-
false
-
- Specifies whether acls should be checked to authorize users viewing the applications.
- If enabled, access control checks are made regardless of what the individual application had
- set for spark.ui.acls.enable when the application was run. The application owner
- will always have authorization to view their own application and any users specified via
- spark.ui.view.acls and groups specified via spark.ui.view.acls.groups
- when the application was run will also have authorization to view that application.
- If disabled, no access control checks are made.
-
-
-
-
spark.history.ui.admin.acls
-
empty
-
- Comma separated list of users/administrators that have view access to all the Spark applications in
- history server. By default only the users permitted to view the application at run-time could
- access the related application history, with this, configured users/administrators could also
- have the permission to access it.
- Putting a "*" in the list means any user can have the privilege of admin.
-
-
-
-
spark.history.ui.admin.acls.groups
-
empty
-
- Comma separated list of groups that have view access to all the Spark applications in
- history server. By default only the groups permitted to view the application at run-time could
- access the related application history, with this, configured groups could also
- have the permission to access it.
- Putting a "*" in the list means any group can have the privilege of admin.
-
-
spark.history.fs.cleaner.enabled
false
@@ -246,7 +214,7 @@ incomplete attempt or the final successful attempt.
2. Incomplete applications are only updated intermittently. The time between updates is defined
by the interval between checks for changed files (`spark.history.fs.update.interval`).
-On larger clusters the update interval may be set to large values.
+On larger clusters, the update interval may be set to large values.
The way to view a running application is actually to view its own web UI.
3. Applications which exited without registering themselves as completed will be listed
@@ -347,6 +315,13 @@ can be identified by their `[attempt-id]`. In the API listed below, when running
/applications/[app-id]/executors
A list of all active executors for the given application.
+ Stack traces of all the threads running within the given active executor.
+ Not available via the history server.
+
+
/applications/[app-id]/allexecutors
A list of all(active and dead) executors for the given application.
@@ -447,7 +422,7 @@ configuration property.
If, say, users wanted to set the metrics namespace to the name of the application, they
can set the `spark.metrics.namespace` property to a value like `${spark.app.name}`. This value is
then expanded appropriately by Spark and is used as the root namespace of the metrics system.
-Non driver and executor metrics are never prefixed with `spark.app.id`, nor does the
+Non-driver and executor metrics are never prefixed with `spark.app.id`, nor does the
`spark.metrics.namespace` property have any such effect on such metrics.
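For example, the namespace expansion could be requested programmatically, as in the sketch below (the application name is a placeholder):

{% highlight scala %}
import org.apache.spark.SparkConf

// Driver and executor metrics will then be rooted at the application name
// instead of the default spark.app.id.
val conf = new SparkConf()
  .setAppName("my-app")
  .set("spark.metrics.namespace", "${spark.app.name}")
{% endhighlight %}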
Spark's metrics are decoupled into different
diff --git a/docs/quick-start.md b/docs/quick-start.md
index 07c520cbee6be..f1a2096cd4dbd 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -11,11 +11,11 @@ This tutorial provides a quick introduction to using Spark. We will first introd
interactive shell (in Python or Scala),
then show how to write applications in Java, Scala, and Python.
-To follow along with this guide, first download a packaged release of Spark from the
+To follow along with this guide, first, download a packaged release of Spark from the
[Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS,
you can download a package for any version of Hadoop.
-Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more complete reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset.
+Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more detailed reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend that you switch to using Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset.
# Interactive Analysis with the Spark Shell
@@ -47,7 +47,7 @@ scala> textFile.first() // First item in this Dataset
res1: String = # Apache Spark
{% endhighlight %}
-Now let's transform this Dataset to a new one. We call `filter` to return a new Dataset with a subset of the items in the file.
+Now let's transform this Dataset into a new one. We call `filter` to return a new Dataset with a subset of the items in the file.
{% highlight scala %}
scala> val linesWithSpark = textFile.filter(line => line.contains("Spark"))
diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md
index 2e29aef7f21a2..b6424090d2fea 100644
--- a/docs/rdd-programming-guide.md
+++ b/docs/rdd-programming-guide.md
@@ -818,7 +818,7 @@ The behavior of the above code is undefined, and may not work as intended. To ex
The variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure.
-In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it.
+In local mode, in some circumstances, the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it.
To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail.
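A minimal sketch of the accumulator alternative to the broken closure above:

{% highlight scala %}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("accumulator-example").getOrCreate()
val sc = spark.sparkContext

// Unlike a closed-over local variable, an accumulator is updated safely
// across executors and its value is visible on the driver.
val counter = sc.longAccumulator("counter")
sc.parallelize(1 to 10).foreach(x => counter.add(x))
println(counter.value)  // 55
{% endhighlight %}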
diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md
index 3c7586e8544ba..408e446ea4822 100644
--- a/docs/running-on-kubernetes.md
+++ b/docs/running-on-kubernetes.md
@@ -17,7 +17,7 @@ container images and entrypoints.**
* A runnable distribution of Spark 2.3 or above.
* A running Kubernetes cluster at version >= 1.6 with access configured to it using
[kubectl](https://kubernetes.io/docs/user-guide/prereqs/). If you do not already have a working Kubernetes cluster,
-you may setup a test cluster on your local machine using
+you may set up a test cluster on your local machine using
[minikube](https://kubernetes.io/docs/getting-started-guides/minikube/).
* We recommend using the latest release of minikube with the DNS addon enabled.
* Be aware that the default minikube configuration is not enough for running Spark applications.
@@ -126,29 +126,6 @@ Those dependencies can be added to the classpath by referencing them with `local
dependencies in custom-built Docker images in `spark-submit`. Note that using application dependencies from the submission
client's local file system is currently not yet supported.
-
-### Using Remote Dependencies
-When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods
-need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading
-the dependencies so the driver and executor containers can use them locally.
-
-The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and
-`spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g.,
-the main application jar. The following shows an example of using remote dependencies with the `spark-submit` command:
-
-```bash
-$ bin/spark-submit \
- --master k8s://https://: \
- --deploy-mode cluster \
- --name spark-pi \
- --class org.apache.spark.examples.SparkPi \
- --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar
- --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2
- --conf spark.executor.instances=5 \
- --conf spark.kubernetes.container.image= \
- https://path/to/examples.jar
-```
-
## Secret Management
Kubernetes [Secrets](https://kubernetes.io/docs/concepts/configuration/secret/) can be used to provide credentials for a
Spark application to access secured services. To mount a user-specified secret into the driver container, users can use
@@ -163,9 +140,11 @@ namespace as that of the driver and executor pods. For example, to mount a secre
--conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets
```
-Note that if an init-container is used, any secret mounted into the driver container will also be mounted into the
-init-container of the driver. Similarly, any secret mounted into an executor container will also be mounted into the
-init-container of the executor.
+To use a secret through an environment variable, use the following options to the `spark-submit` command:
+```
+--conf spark.kubernetes.driver.secretKeyRef.ENV_NAME=name:key
+--conf spark.kubernetes.executor.secretKeyRef.ENV_NAME=name:key
+```
## Introspection and Debugging
@@ -248,7 +227,7 @@ that allows driver pods to create pods and services under the default Kubernetes
[RBAC](https://kubernetes.io/docs/admin/authorization/rbac/) policies. Sometimes users may need to specify a custom
service account that has the right role granted. Spark on Kubernetes supports specifying a custom service account to
be used by the driver pod through the configuration property
-`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example to make the driver pod
+`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example, to make the driver pod
use the `spark` service account, a user simply adds the following option to the `spark-submit` command:
```
@@ -291,7 +270,6 @@ future versions of the spark-kubernetes integration.
Some of these include:
-* PySpark
* R
* Dynamic Executor Scaling
* Local File Dependency Management
@@ -348,6 +326,13 @@ specific to Spark on Kubernetes.
Container image pull policy used when pulling images within Kubernetes.
+
+
spark.kubernetes.container.image.pullSecrets
+
+
+ Comma separated list of Kubernetes secrets used to pull images from private image registries.
+
+
spark.kubernetes.allocation.batch.size
5
@@ -576,14 +561,23 @@ specific to Spark on Kubernetes.
spark.kubernetes.driver.limit.cores
(none)
- Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod.
+ Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod.
+
+
spark.kubernetes.executor.request.cores
+
(none)
+
+ Specify the cpu request for each executor pod. Values conform to the Kubernetes [convention](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#meaning-of-cpu).
+ Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in [CPU units](https://kubernetes.io/docs/tasks/configure-pod-container/assign-cpu-resource/#cpu-units).
+ This is distinct from spark.executor.cores: it is only used and takes precedence over spark.executor.cores for specifying the executor pod cpu request if set. Task
+ parallelism, e.g., the number of tasks an executor can run concurrently, is not affected by this.
+
spark.kubernetes.executor.limit.cores
(none)
- Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application.
+ Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application.
@@ -605,59 +599,50 @@ specific to Spark on Kubernetes.
- Location to download jars to in the driver and executors.
- This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods.
-
- Location to download jars to in the driver and executors.
- This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods.
+ Add the Kubernetes Secret named SecretName to the driver pod on the path specified in the value. For example,
+ spark.kubernetes.driver.secrets.spark-secret=/etc/secrets.
-
spark.kubernetes.mountDependencies.timeout
-
300s
+
spark.kubernetes.executor.secrets.[SecretName]
+
(none)
- Timeout in seconds before aborting the attempt to download and unpack dependencies from remote locations into
- the driver and executor pods.
+ Add the Kubernetes Secret named SecretName to the executor pod on the path specified in the value. For example,
+ spark.kubernetes.executor.secrets.spark-secret=/etc/secrets.
- Maximum number of remote dependencies to download simultaneously in a driver or executor pod.
+ Add as an environment variable to the driver container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example,
+ spark.kubernetes.driver.secretKeyRef.ENV_VAR=spark-secret:key.
-
spark.kubernetes.initContainer.image
-
(value of spark.kubernetes.container.image)
+
spark.kubernetes.executor.secretKeyRef.[EnvName]
+
(none)
- Custom container image for the init container of both driver and executors.
+ Add as an environment variable to the executor container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example,
+ spark.kubernetes.executor.secretKeyRef.ENV_VAR=spark-secret:key.
-
spark.kubernetes.driver.secrets.[SecretName]
-
(none)
+
spark.kubernetes.memoryOverheadFactor
+
0.1
- Add the Kubernetes Secret named SecretName to the driver pod on the path specified in the value. For example,
- spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. Note that if an init-container is used,
- the secret will also be added to the init-container in the driver pod.
+ This sets the Memory Overhead Factor that will allocate memory to non-JVM memory, which includes off-heap memory allocations, non-JVM tasks, and various system processes. For JVM-based jobs this value defaults to 0.10; for non-JVM jobs it defaults to 0.40.
+ This is done because non-JVM tasks need more non-JVM heap space and such tasks commonly fail with "Memory Overhead Exceeded" errors. This preempts the error with a higher default.
-
spark.kubernetes.executor.secrets.[SecretName]
-
(none)
+
spark.kubernetes.pyspark.pythonversion
+
"2"
- Add the Kubernetes Secret named SecretName to the executor pod on the path specified in the value. For example,
- spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. Note that if an init-container is used,
- the secret will also be added to the init-container in the executor pod.
+ This sets the major Python version of the docker image used to run the driver and executor containers. Can either be 2 or 3.
-
\ No newline at end of file
+
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 2bb5ecf1b8509..66ffb17949845 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -82,6 +82,27 @@ a Spark driver program configured to connect to Mesos.
Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure
`spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location.
+## Authenticating to Mesos
+
+When Mesos Framework authentication is enabled, it is necessary to provide a principal and secret by which to authenticate Spark to Mesos. Each Spark job will register with Mesos as a separate framework.
+
+Depending on your deployment environment, you may wish to create a single set of framework credentials that are shared across all users, or create framework credentials for each user. Creating and managing framework credentials should be done following the Mesos [Authentication documentation](http://mesos.apache.org/documentation/latest/authentication/).
+
+Framework credentials may be specified in a variety of ways, depending on your deployment environment and security requirements. The simplest way is to specify the `spark.mesos.principal` and `spark.mesos.secret` values directly in your Spark configuration. Alternatively, you may specify these values indirectly by instead specifying `spark.mesos.principal.file` and `spark.mesos.secret.file`; these settings point to files containing the principal and secret. These files must be plaintext files in UTF-8 encoding. Combined with appropriate file ownership and mode/ACLs, this provides a more secure way to specify these credentials.
+
+Additionally, if you prefer to use environment variables, you can specify all of the above via environment variables instead; the environment variable names are simply the configuration settings uppercased with `.` replaced with `_`, e.g. `SPARK_MESOS_PRINCIPAL`.
+
+### Credential Specification Preference Order
+
+Please note that if you specify multiple ways to obtain the credentials then the following preference order applies. Spark will use the first valid value found and any subsequent values are ignored:
+
+- `spark.mesos.principal` configuration setting
+- `SPARK_MESOS_PRINCIPAL` environment variable
+- `spark.mesos.principal.file` configuration setting
+- `SPARK_MESOS_PRINCIPAL_FILE` environment variable
+
+An equivalent order applies for the secret. Essentially we prefer the configuration to be specified directly rather than indirectly by files, and we prefer that configuration settings are used over environment variables.
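As a rough sketch, the direct-configuration route could look like the following; the principal name and secret file path are placeholders:

{% highlight scala %}
import org.apache.spark.SparkConf

// Credentials supplied directly (principal) and indirectly (secret file).
val conf = new SparkConf()
  .setAppName("mesos-auth-example")
  .set("spark.mesos.principal", "spark-framework")
  .set("spark.mesos.secret.file", "/etc/spark/mesos-secret")
{% endhighlight %}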
+
## Uploading Spark Package
When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary
@@ -204,7 +225,7 @@ details and default values.
Executors are brought up eagerly when the application starts, until
`spark.cores.max` is reached. If you don't set `spark.cores.max`, the
Spark application will consume all resources offered to it by Mesos,
-so we of course urge you to set this variable in any sort of
+so we, of course, urge you to set this variable in any sort of
multi-tenant cluster, including one which runs multiple concurrent
Spark applications.
@@ -212,14 +233,14 @@ The scheduler will start executors round-robin on the offers Mesos
gives it, but there are no spread guarantees, as Mesos does not
provide such guarantees on the offer stream.
-In this mode spark executors will honor port allocation if such is
-provided from the user. Specifically if the user defines
+In this mode, Spark executors will honor port allocation if it is
+provided by the user. Specifically, if the user defines
`spark.blockManager.port` in Spark configuration,
the mesos scheduler will check the available offers for a valid port
range containing the port numbers. If no such range is available it will
not launch any task. If no restriction is imposed on port numbers by the
user, ephemeral ports are used as usual. This port honouring implementation
-implies one task per host if the user defines a port. In the future network
+implies one task per host if the user defines a port. In the future, network
isolation shall be supported.
The benefit of coarse-grained mode is much lower startup overhead, but
@@ -427,7 +448,14 @@ See the [configuration page](configuration.html) for information on Spark config
spark.mesos.principal
(none)
- Set the principal with which Spark framework will use to authenticate with Mesos.
+ Set the principal that the Spark framework will use to authenticate with Mesos. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL`.
+
+
+
+
spark.mesos.principal.file
+
(none)
+
+ Set the file containing the principal that the Spark framework will use to authenticate with Mesos. Allows specifying the principal indirectly in more security-conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL_FILE`.
@@ -435,7 +463,15 @@ See the [configuration page](configuration.html) for information on Spark config
(none)
Set the secret that the Spark framework will use to authenticate with Mesos. Used, for example, when
- authenticating with the registry.
+ authenticating with the registry. You can also specify this via the environment variable `SPARK_MESOS_SECRET`.
+
+
+
+
spark.mesos.secret.file
+
(none)
+
+ Set the file containing the secret that the Spark framework will use to authenticate with Mesos. Used, for example, when
+ authenticating with the registry. Allows for specifying the secret indirectly in more security-conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_SECRET_FILE`.
@@ -450,7 +486,7 @@ See the [configuration page](configuration.html) for information on Spark config
spark.mesos.constraints
(none)
- Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting
+ Attribute-based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting
applies only to executors. Refer to Mesos
Attributes & Resources for more information on attributes.
@@ -717,6 +753,18 @@ See the [configuration page](configuration.html) for information on Spark config
spark.cores.max is reached
+
+
spark.mesos.appJar.local.resolution.mode
+
host
+
+ Provides support for the `local:///` scheme to reference the app jar resource in cluster mode.
+ If the user uses a local resource (`local:///path/to/jar`) and the config option is not set, it defaults to `host`, i.e.
+ the mesos fetcher tries to get the resource from the host's file system.
+ If the value is unknown, it prints a warning message in the dispatcher logs and defaults to `host`.
+ If the value is `container`, then spark-submit in the container will use the jar in the container's path:
+ `/path/to/jar`.
+
+
# Troubleshooting and Debugging
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index c010af35f8d2e..4dbcbeafbbd9d 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -2,6 +2,8 @@
layout: global
title: Running Spark on YARN
---
+* This will become a table of contents (this text will be scraped).
+{:toc}
Support for running on [YARN (Hadoop
NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html)
@@ -131,9 +133,8 @@ To use a custom metrics.properties for the application master and executors, upd
spark.yarn.am.waitTime
100s
- In cluster mode, time for the YARN Application Master to wait for the
- SparkContext to be initialized. In client mode, time for the YARN Application Master to wait
- for the driver to connect to it.
+ Only used in cluster mode. Time for the YARN Application Master to wait for the
+ SparkContext to be initialized.
@@ -217,8 +218,8 @@ To use a custom metrics.properties for the application master and executors, upd
spark.yarn.dist.forceDownloadSchemes
(none)
- Comma-separated list of schemes for which files will be downloaded to the local disk prior to
- being added to YARN's distributed cache. For use in cases where the YARN service does not
+ Comma-separated list of schemes for which files will be downloaded to the local disk prior to
+ being added to YARN's distributed cache. For use in cases where the YARN service does not
support schemes that are supported by Spark, like http, https and ftp.
@@ -265,19 +266,6 @@ To use a custom metrics.properties for the application master and executors, upd
distribution.
-
-
spark.yarn.access.hadoopFileSystems
-
(none)
-
- A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For
- example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032,
- webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed
- and Kerberos must be properly configured to be able to access them (either in the same realm
- or in a trusted realm). Spark acquires security tokens for each of the filesystems so that
- the Spark application can access those remote Hadoop filesystems. spark.yarn.access.namenodes
- is deprecated, please use this instead.
-
-
spark.yarn.appMasterEnv.[EnvironmentVariableName]
(none)
@@ -373,31 +361,6 @@ To use a custom metrics.properties for the application master and executors, upd
in YARN ApplicationReports, which can be used for filtering when querying YARN apps.
-
-
spark.yarn.keytab
-
(none)
-
- The full path to the file that contains the keytab for the principal specified above.
- This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache,
- for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master)
-
-
-
-
spark.yarn.principal
-
(none)
-
- Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master)
-
-
-
-
spark.yarn.kerberos.relogin.period
-
1m
-
- How often to check whether the kerberos TGT should be renewed. This should be set to a value
- that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled).
- The default value should be enough for most deployments.
-
-
spark.yarn.config.gatewayPath
(none)
@@ -424,17 +387,6 @@ To use a custom metrics.properties for the application master and executors, upd
See spark.yarn.config.gatewayPath.
-
-
spark.security.credentials.${service}.enabled
-
true
-
- Controls whether to obtain credentials for services when security is enabled.
- By default, credentials for all supported services are retrieved when those services are
- configured, but it's possible to disable that behavior if it somehow conflicts with the
- application being run. For further details please see
- [Running in a Secure Cluster](running-on-yarn.html#running-in-a-secure-cluster)
-
-
spark.yarn.rolledLog.includePattern
(none)
@@ -465,51 +417,110 @@ To use a custom metrics.properties for the application master and executors, upd
- Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured.
- In `cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `client` mode, only the Spark executors do.
-- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN.
+- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example, you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name `appSees.txt` to reference it when running on YARN.
- The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files.
-# Running in a Secure Cluster
+# Kerberos
+
+Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page.
-As covered in [security](security.html), Kerberos is used in a secure Hadoop cluster to
-authenticate principals associated with services and clients. This allows clients to
-make requests of these authenticated services; the services to grant rights
-to the authenticated principals.
+In YARN mode, when accessing Hadoop filesystems, Spark will automatically obtain delegation tokens
+for:
-Hadoop services issue *hadoop tokens* to grant access to the services and data.
-Clients must first acquire tokens for the services they will access and pass them along with their
-application as it is launched in the YARN cluster.
+- the filesystem hosting the staging directory of the Spark application (which is the default
+ filesystem if `spark.yarn.stagingDir` is not set);
+- if Hadoop federation is enabled, all the federated filesystems in the configuration.
-For a Spark application to interact with any of the Hadoop filesystem (for example hdfs, webhdfs, etc), HBase and Hive, it must acquire the relevant tokens
-using the Kerberos credentials of the user launching the application
-—that is, the principal whose identity will become that of the launched Spark application.
+If an application needs to interact with other secure Hadoop filesystems, their URIs need to be
+explicitly provided to Spark at launch time. This is done by listing them in the
+`spark.yarn.access.hadoopFileSystems` property, described in the configuration section below.
-This is normally done at launch time: in a secure cluster Spark will automatically obtain a
-token for the cluster's default Hadoop filesystem, and potentially for HBase and Hive.
+The YARN integration also supports custom delegation token providers using the Java Services
+mechanism (see `java.util.ServiceLoader`). Implementations of
+`org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` can be made available to Spark
+by listing their names in the corresponding file in the jar's `META-INF/services` directory. These
+providers can be disabled individually by setting `spark.security.credentials.{service}.enabled` to
+`false`, where `{service}` is the name of the credential provider.
-An HBase token will be obtained if HBase is in on classpath, the HBase configuration declares
-the application is secure (i.e. `hbase-site.xml` sets `hbase.security.authentication` to `kerberos`),
-and `spark.security.credentials.hbase.enabled` is not set to `false`.
+## YARN-specific Kerberos Configuration
-Similarly, a Hive token will be obtained if Hive is on the classpath, its configuration
-includes a URI of the metadata store in `"hive.metastore.uris`, and
-`spark.security.credentials.hive.enabled` is not set to `false`.
+
+
Property Name
Default
Meaning
+
+
spark.yarn.keytab
+
(none)
+
+ The full path to the file that contains the keytab for the principal specified above. This keytab
+ will be copied to the node running the YARN Application Master via the YARN Distributed Cache, and
+ will be used for renewing the login tickets and the delegation tokens periodically. Equivalent to
+ the --keytab command line argument.
-If an application needs to interact with other secure Hadoop filesystems, then
-the tokens needed to access these clusters must be explicitly requested at
-launch time. This is done by listing them in the `spark.yarn.access.hadoopFileSystems` property.
+ (Works also with the "local" master.)
+
+
+
+
spark.yarn.principal
+
(none)
+
+ Principal to be used to log in to the KDC, while running on secure clusters. Equivalent to the
+ --principal command line argument.
+
+ (Works also with the "local" master.)
+
+
+
+
spark.yarn.access.hadoopFileSystems
+
(none)
+
+ A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For
+ example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032,
+ webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed
+ and Kerberos must be properly configured to be able to access them (either in the same realm
+ or in a trusted realm). Spark acquires security tokens for each of the filesystems so that
+ the Spark application can access those remote Hadoop filesystems.
+
+
+
+
spark.yarn.kerberos.relogin.period
+
1m
+
+ How often to check whether the kerberos TGT should be renewed. This should be set to a value
+ that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled).
+ The default value should be enough for most deployments.
+
+
+
+
+## Troubleshooting Kerberos
+
+Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to
+enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG`
+environment variable.
+
+```bash
+export HADOOP_JAAS_DEBUG=true
+```
+
+The JDK classes can be configured to enable extra logging of their Kerberos and
+SPNEGO/REST authentication via the system properties `sun.security.krb5.debug`
+and `sun.security.spnego.debug=true`
```
-spark.yarn.access.hadoopFileSystems hdfs://ireland.example.org:8020/,webhdfs://frankfurt.example.org:50070/
+-Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true
```
-Spark supports integrating with other security-aware services through Java Services mechanism (see
-`java.util.ServiceLoader`). To do that, implementations of `org.apache.spark.deploy.yarn.security.ServiceCredentialProvider`
-should be available to Spark by listing their names in the corresponding file in the jar's
-`META-INF/services` directory. These plug-ins can be disabled by setting
-`spark.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of
-credential provider.
+All these options can be enabled in the Application Master:
-## Configuring the External Shuffle Service
+```
+spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true
+spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true
+```
+
+Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log
+will include a list of all tokens obtained, and their expiry details.
+
+
+# Configuring the External Shuffle Service
To start the Spark Shuffle Service on each `NodeManager` in your YARN cluster, follow these
instructions:
@@ -542,7 +553,7 @@ The following extra configuration options are available when the shuffle service
-## Launching your application with Apache Oozie
+# Launching your application with Apache Oozie
Apache Oozie can launch Spark applications as part of a workflow.
In a secure cluster, the launched application will need the relevant tokens to access the cluster's
@@ -576,35 +587,7 @@ spark.security.credentials.hbase.enabled false
The configuration option `spark.yarn.access.hadoopFileSystems` must be unset.
-## Troubleshooting Kerberos
-
-Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to
-enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG`
-environment variable.
-
-```bash
-export HADOOP_JAAS_DEBUG=true
-```
-
-The JDK classes can be configured to enable extra logging of their Kerberos and
-SPNEGO/REST authentication via the system properties `sun.security.krb5.debug`
-and `sun.security.spnego.debug=true`
-
-```
--Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true
-```
-
-All these options can be enabled in the Application Master:
-
-```
-spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true
-spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true
-```
-
-Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log
-will include a list of all tokens obtained, and their expiry details
-
-## Using the Spark History Server to replace the Spark Web UI
+# Using the Spark History Server to replace the Spark Web UI
It is possible to use the Spark History Server application page as the tracking URL for running
applications when the application UI is disabled. This may be desirable on secure clusters, or to
diff --git a/docs/security.md b/docs/security.md
index bebc28ddbfb0e..8c0c66fb5a285 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -3,41 +3,323 @@ layout: global
displayTitle: Spark Security
title: Security
---
+* This will become a table of contents (this text will be scraped).
+{:toc}
-Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows:
+# Spark RPC
-* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret.
-* For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications.
+## Authentication
-## Web UI
+Spark currently supports authentication for RPC channels using a shared secret. Authentication can
+be turned on by setting the `spark.authenticate` configuration parameter.
-The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting
-and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via [SSL settings](security.html#ssl-configuration).
+The exact mechanism used to generate and distribute the shared secret is deployment-specific.
-### Authentication
+For Spark on [YARN](running-on-yarn.html) and local deployments, Spark will automatically handle
+generating and distributing the shared secret. Each application will use a unique shared secret. In
+the case of YARN, this feature relies on YARN RPC encryption being enabled for the distribution of
+secrets to be secure.
-A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable`, `spark.ui.view.acls` and `spark.ui.view.acls.groups` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters.
+For other resource managers, `spark.authenticate.secret` must be configured on each of the nodes.
+This secret will be shared by all the daemons and applications, so this deployment configuration is
+not as secure as the above, especially when considering multi-tenant clusters.
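+
+As an illustrative sketch (the secret value is a placeholder), this amounts to adding the
+following to `spark-defaults.conf` on every node:
+
+```
+spark.authenticate        true
+spark.authenticate.secret <shared-secret-value>
+```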
-Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable`, `spark.modify.acls` and `spark.modify.acls.groups`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces.
-Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the configs `spark.admin.acls` and `spark.admin.acls.groups`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications.
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.authenticate</code></td>
+  <td>false</td>
+  <td>Whether Spark authenticates its internal connections.</td>
+</tr>
+<tr>
+  <td><code>spark.authenticate.secret</code></td>
+  <td>None</td>
+  <td>
+    The secret key used for authentication. See above for when this configuration should be set.
+  </td>
+</tr>
+</table>
+
+## Encryption
-## Event Logging
+Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC
+authentication must also be enabled and properly configured. AES encryption uses the
+[Apache Commons Crypto](http://commons.apache.org/proper/commons-crypto/) library, and Spark's
+configuration system allows access to that library's configuration for advanced users.
-If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access.
+There is also support for SASL-based encryption, although it should be considered deprecated. It
+is still required when talking to shuffle services from Spark versions older than 2.2.0.
-## Encryption
+The following table describes the different options available for configuring this feature.
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.network.crypto.enabled</code></td>
+  <td>false</td>
+  <td>
+    Enable AES-based RPC encryption, including the new authentication protocol added in 2.2.0.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.network.crypto.keyLength</code></td>
+  <td>128</td>
+  <td>
+    The length in bits of the encryption key to generate. Valid values are 128, 192 and 256.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.network.crypto.keyFactoryAlgorithm</code></td>
+  <td>PBKDF2WithHmacSHA1</td>
+  <td>
+    The key factory algorithm to use when generating encryption keys. Should be one of the
+    algorithms supported by the <code>javax.crypto.SecretKeyFactory</code> class in the JRE being used.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.network.crypto.config.*</code></td>
+  <td>None</td>
+  <td>
+    Configuration values for the commons-crypto library, such as which cipher implementations to
+    use. The config name should be the name of commons-crypto configuration without the
+    <code>commons.crypto</code> prefix.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.network.crypto.saslFallback</code></td>
+  <td>true</td>
+  <td>
+    Whether to fall back to SASL authentication if authentication fails using Spark's internal
+    mechanism. This is useful when the application is connecting to old shuffle services that
+    do not support the internal Spark authentication protocol. On the shuffle service side,
+    disabling this feature will block older clients from authenticating.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.authenticate.enableSaslEncryption</code></td>
+  <td>false</td>
+  <td>
+    Enable SASL-based encrypted communication.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.network.sasl.serverAlwaysEncrypt</code></td>
+  <td>false</td>
+  <td>
+    Disable unencrypted connections for ports using SASL authentication. This will deny connections
+    from clients that have authentication enabled, but do not request SASL-based encryption.
+  </td>
+</tr>
+</table>
+
+# Local Storage Encryption
+
+Spark supports encrypting temporary data written to local disks. This covers shuffle files, shuffle
+spills and data blocks stored on disk (for both caching and broadcast variables). It does not cover
+encrypting output data generated by applications with APIs such as `saveAsHadoopFile` or
+`saveAsTable`.
+
+The following settings cover enabling encryption for data written to disk:
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.io.encryption.enabled</code></td>
+  <td>false</td>
+  <td>
+    Enable local disk I/O encryption. Currently supported by all modes except Mesos. It's strongly
+    recommended that RPC encryption be enabled when using this feature.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.io.encryption.keySizeBits</code></td>
+  <td>128</td>
+  <td>
+    IO encryption key size in bits. Supported values are 128, 192 and 256.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.io.encryption.keygen.algorithm</code></td>
+  <td>HmacSHA1</td>
+  <td>
+    The algorithm to use when generating the IO encryption key. The supported algorithms are
+    described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm
+    Name Documentation.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.io.encryption.commons.config.*</code></td>
+  <td>None</td>
+  <td>
+    Configuration values for the commons-crypto library, such as which cipher implementations to
+    use. The config name should be the name of commons-crypto configuration without the
+    <code>commons.crypto</code> prefix.
+  </td>
+</tr>
+</table>
+
+# Web UI
+
+## Authentication and Authorization
+
+Enabling authentication for the Web UIs is done using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html).
+You will need a filter that implements the authentication method you want to deploy. Spark does not
+provide any built-in authentication filters.
+
+Spark also supports access control to the UI when an authentication filter is present. Each
+application can be configured with its own separate access control lists (ACLs). Spark
+differentiates between "view" permissions (who is allowed to see the application's UI), and "modify"
+permissions (who can do things like kill jobs in a running application).
+
+ACLs can be configured for either users or groups. Configuration entries accept comma-separated
+lists as input, meaning multiple users or groups can be given the desired privileges. This can be
+used if you run on a shared cluster and have a set of administrators or developers who need to
+monitor applications they may not have started themselves. A wildcard (`*`) added to a specific ACL
+means that all users will have the respective privilege. By default, only the user submitting the
+application is added to the ACLs.
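+
+For example, a hypothetical deployment could combine an authentication filter with ACLs in
+`spark-defaults.conf` (the filter class and the user and group names below are placeholders, not
+something Spark provides):
+
+```
+spark.ui.filters           org.example.auth.MyAuthenticationFilter
+spark.acls.enable          true
+spark.admin.acls           alice,bob
+spark.admin.acls.groups    admins
+spark.ui.view.acls.groups  support
+```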
+
+Group membership is established by using a configurable group mapping provider. The mapper is
+configured using the spark.user.groups.mapping config option, described in the table
+below.
+
+The following options control the authentication of Web UIs:
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.ui.filters</code></td>
+  <td>None</td>
+  <td>
+    See the Spark UI configuration for how to configure filters.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.acls.enable</code></td>
+  <td>false</td>
+  <td>
+    Whether UI ACLs should be enabled. If enabled, this checks to see if the user has access
+    permissions to view or modify the application. Note this requires the user to be authenticated,
+    so if no authentication filter is installed, this option does not do anything.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.admin.acls</code></td>
+  <td>None</td>
+  <td>
+    Comma-separated list of users that have view and modify access to the Spark application.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.admin.acls.groups</code></td>
+  <td>None</td>
+  <td>
+    Comma-separated list of groups that have view and modify access to the Spark application.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.modify.acls</code></td>
+  <td>None</td>
+  <td>
+    Comma-separated list of users that have modify access to the Spark application.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.modify.acls.groups</code></td>
+  <td>None</td>
+  <td>
+    Comma-separated list of groups that have modify access to the Spark application.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.ui.view.acls</code></td>
+  <td>None</td>
+  <td>
+    Comma-separated list of users that have view access to the Spark application.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.ui.view.acls.groups</code></td>
+  <td>None</td>
+  <td>
+    Comma-separated list of groups that have view access to the Spark application.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.user.groups.mapping</code></td>
+  <td><code>org.apache.spark.security.ShellBasedGroupsMappingProvider</code></td>
+  <td>
+    The list of groups for a user is determined by a group mapping service defined by the trait
+    <code>org.apache.spark.security.GroupMappingServiceProvider</code>, which can be configured by
+    this property.
+
+    By default, a Unix shell-based implementation is used, which collects this information
+    from the host OS.
+
+    Note: This implementation supports only Unix/Linux-based environments.
+    Windows environment is currently not supported. However, a new platform/protocol can
+    be supported by implementing the trait mentioned above.
+  </td>
+</tr>
+</table>
+
+On YARN, the view and modify ACLs are provided to the YARN service when submitting applications, and
+control who has the respective privileges via YARN interfaces.
-Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service
-and the RPC endpoints. Shuffle files can also be encrypted if desired.
+## Spark History Server ACLs
-### SSL Configuration
+Authentication for the SHS Web UI is enabled the same way as for regular applications, using
+servlet filters.
+
+To enable authorization in the SHS, a few extra options are used:
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.history.ui.acls.enable</code></td>
+  <td>false</td>
+  <td>
+    Specifies whether ACLs should be checked to authorize users viewing the applications in
+    the history server. If enabled, access control checks are performed regardless of what the
+    individual applications had set for <code>spark.ui.acls.enable</code>. The application owner
+    will always have authorization to view their own application and any users specified via
+    <code>spark.ui.view.acls</code> and groups specified via <code>spark.ui.view.acls.groups</code>
+    when the application was run will also have authorization to view that application.
+    If disabled, no access control checks are made for any application UIs available through
+    the history server.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.history.ui.admin.acls</code></td>
+  <td>None</td>
+  <td>
+    Comma-separated list of users that have view access to all the Spark applications in the
+    history server.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.history.ui.admin.acls.groups</code></td>
+  <td>None</td>
+  <td>
+    Comma-separated list of groups that have view access to all the Spark applications in the
+    history server.
+  </td>
+</tr>
+</table>
+
+The SHS uses the same options to configure the group mapping provider as regular applications.
+In this case, the group mapping provider will apply to all UIs served by the SHS, and individual
+application configurations will be ignored.
+
+## SSL Configuration
Configuration for SSL is organized hierarchically. The user can configure the default SSL settings
which will be used for all the supported communication protocols unless they are overwritten by
protocol-specific settings. This way the user can easily provide the common settings for all the
-protocols without disabling the ability to configure each one individually. The common SSL settings
-are at `spark.ssl` namespace in Spark configuration. The following table describes the
-component-specific configuration namespaces used to override the default settings:
+protocols without disabling the ability to configure each one individually. The following table
+describes the SSL configuration namespaces:
@@ -45,8 +327,11 @@ component-specific configuration namespaces used to override the default setting
Component
-  <tr>
-    <td><code>spark.ssl.fs</code></td>
-    <td>File download client (used to download jars and files from HTTPS-enabled servers).</td>
-  </tr>
+  <tr>
+    <td><code>spark.ssl</code></td>
+    <td>
+      The default SSL configuration. These values will apply to all namespaces below, unless
+      explicitly overridden at the namespace level.
+    </td>
+  </tr>
spark.ssl.ui
@@ -62,49 +347,205 @@ component-specific configuration namespaces used to override the default setting
-The full breakdown of available SSL options can be found on the [configuration page](configuration.html).
-SSL must be configured on each node and configured for each component involved in communication using the particular protocol.
+The full breakdown of available SSL options can be found below. The `${ns}` placeholder should be
+replaced with one of the above namespaces.
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>${ns}.enabled</code></td>
+  <td>false</td>
+  <td>Enables SSL. When enabled, <code>${ns}.ssl.protocol</code> is required.</td>
+</tr>
+<tr>
+  <td><code>${ns}.port</code></td>
+  <td>None</td>
+  <td>
+    The port where the SSL service will listen on.
+
+    The port must be defined within a specific namespace configuration. The default
+    namespace is ignored when reading this configuration.
+
+    When not set, the SSL port will be derived from the non-SSL port for the
+    same service. A value of "0" will make the service bind to an ephemeral port.
+  </td>
+</tr>
+<tr>
+  <td><code>${ns}.enabledAlgorithms</code></td>
+  <td>None</td>
+  <td>
+    A comma-separated list of ciphers. The specified ciphers must be supported by JVM.
+
+    The reference list of protocols can be found in the "JSSE Cipher Suite Names" section
+    of the Java security guide. The list for Java 8 can be found at this page.
+
+    Note: If not set, the default cipher suite for the JRE will be used.
+  </td>
+</tr>
+<tr>
+  <td><code>${ns}.keyPassword</code></td>
+  <td>None</td>
+  <td>
+    The password to the private key in the key store.
+  </td>
+</tr>
+<tr>
+  <td><code>${ns}.keyStore</code></td>
+  <td>None</td>
+  <td>
+    Path to the key store file. The path can be absolute or relative to the directory in which the
+    process is started.
+  </td>
+</tr>
+<tr>
+  <td><code>${ns}.keyStorePassword</code></td>
+  <td>None</td>
+  <td>Password to the key store.</td>
+</tr>
+<tr>
+  <td><code>${ns}.keyStoreType</code></td>
+  <td>JKS</td>
+  <td>The type of the key store.</td>
+</tr>
+<tr>
+  <td><code>${ns}.protocol</code></td>
+  <td>None</td>
+  <td>
+    TLS protocol to use. The protocol must be supported by JVM.
+
+    The reference list of protocols can be found in the "Additional JSSE Standard Names"
+    section of the Java security guide. For Java 8, the list can be found at this page.
+  </td>
+</tr>
+<tr>
+  <td><code>${ns}.needClientAuth</code></td>
+  <td>false</td>
+  <td>Whether to require client authentication.</td>
+</tr>
+<tr>
+  <td><code>${ns}.trustStore</code></td>
+  <td>None</td>
+  <td>
+    Path to the trust store file. The path can be absolute or relative to the directory in which
+    the process is started.
+  </td>
+</tr>
+<tr>
+  <td><code>${ns}.trustStorePassword</code></td>
+  <td>None</td>
+  <td>Password for the trust store.</td>
+</tr>
+<tr>
+  <td><code>${ns}.trustStoreType</code></td>
+  <td>JKS</td>
+  <td>The type of the trust store.</td>
+</tr>
+</table>
+
+## Preparing the key stores
+
+Key stores can be generated by `keytool` program. The reference documentation for this tool for
+Java 8 is [here](https://docs.oracle.com/javase/8/docs/technotes/tools/unix/keytool.html).
+The most basic steps to configure the key stores and the trust store for a Spark Standalone
+deployment mode are as follows (a sketch of the corresponding `keytool` commands is shown after
+this list):
+
+* Generate a key pair for each node
+* Export the public key of the key pair to a file on each node
+* Import all exported public keys into a single trust store
+* Distribute the trust store to the cluster nodes
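+
+A minimal sketch of these steps using `keytool` (host names, aliases, passwords and file names are
+placeholders, not part of any Spark configuration):
+
+```bash
+# On each node: generate a key pair and export its public certificate
+keytool -keystore keystore.jks -alias $(hostname) -genkeypair -keyalg RSA -keysize 2048 \
+  -dname "CN=$(hostname)" -storepass changeit -keypass changeit
+keytool -keystore keystore.jks -alias $(hostname) -exportcert -file $(hostname).cer \
+  -storepass changeit
+
+# On one machine, after collecting all exported certificates: import them into a single
+# trust store, then copy truststore.jks to every node in the cluster
+for cert in *.cer; do
+  keytool -keystore truststore.jks -alias "${cert%.cer}" -importcert -file "$cert" \
+    -storepass changeit -noprompt
+done
+```
+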
### YARN mode
-The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark.
-For long-running apps like Spark Streaming apps to be able to write to HDFS, it is possible to pass a principal and keytab to `spark-submit` via the `--principal` and `--keytab` parameters respectively. The keytab passed in will be copied over to the machine running the Application Master via the Hadoop Distributed Cache (securely - if YARN is configured with SSL and HDFS encryption is enabled). The Kerberos login will be periodically renewed using this principal and keytab and the delegation tokens required for HDFS will be generated periodically so the application can continue writing to HDFS.
+To provide a local trust store or key store file to drivers running in cluster mode, they can be
+distributed with the application using the `--files` command line argument (or the equivalent
+`spark.files` configuration). The files will be placed in the driver's working directory, so the TLS
+configuration should just reference the file name with no absolute path.
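+
+For example, a hypothetical submission (file names, passwords and the application jar are
+placeholders) could look like:
+
+```bash
+spark-submit \
+  --master yarn \
+  --deploy-mode cluster \
+  --files keystore.jks,truststore.jks \
+  --conf spark.ssl.enabled=true \
+  --conf spark.ssl.protocol=TLSv1.2 \
+  --conf spark.ssl.keyStore=keystore.jks \
+  --conf spark.ssl.keyStorePassword=changeit \
+  --conf spark.ssl.keyPassword=changeit \
+  --conf spark.ssl.trustStore=truststore.jks \
+  --conf spark.ssl.trustStorePassword=changeit \
+  --class com.example.MyApp myapp.jar
+```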
+
+Distributing local key stores this way may require the files to be staged in HDFS (or other similar
+distributed file system used by the cluster), so it's recommended that the underlying file system be
+configured with security in mind (e.g. by enabling authentication and wire encryption).
### Standalone mode
-The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors.
+
+The user needs to provide key stores and configuration options for master and workers. They have to
+be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in
+`SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`.
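+
+As an illustrative sketch (paths and passwords are placeholders), the daemons could pick up a
+common SSL configuration from `conf/spark-env.sh`:
+
+```bash
+# conf/spark-env.sh on the master and worker machines
+export SPARK_DAEMON_JAVA_OPTS="-Dspark.ssl.enabled=true -Dspark.ssl.protocol=TLSv1.2 \
+  -Dspark.ssl.keyStore=/path/to/keystore.jks -Dspark.ssl.keyStorePassword=changeit \
+  -Dspark.ssl.keyPassword=changeit -Dspark.ssl.trustStore=/path/to/truststore.jks \
+  -Dspark.ssl.trustStorePassword=changeit"
+```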
+
+The user may allow the executors to use the SSL settings inherited from the worker process. That
+can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. In that case, the settings
+provided by the user on the client side are not used.
### Mesos mode
-Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based secrets. Spark allows the specification of file-based and environment variable based secrets with the `spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. Depending on the secret store backend secrets can be passed by reference or by value with the `spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, respectively. Reference type secrets are served by the secret store and referred to by name, for example `/mysecret`. Value type secrets are passed on the command line and translated into their appropriate files or environment variables.
+Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based
+secrets. Spark allows the specification of file-based and environment variable based secrets with
+`spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively.
-### Preparing the key-stores
-Key-stores can be generated by `keytool` program. The reference documentation for this tool is
-[here](https://docs.oracle.com/javase/7/docs/technotes/tools/solaris/keytool.html). The most basic
-steps to configure the key-stores and the trust-store for the standalone deployment mode is as
-follows:
+Depending on the secret store backend secrets can be passed by reference or by value with the
+`spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties,
+respectively.
-* Generate a keys pair for each node
-* Export the public key of the key pair to a file on each node
-* Import all exported public keys into a single trust-store
-* Distribute the trust-store over the nodes
+Reference type secrets are served by the secret store and referred to by name, for example
+`/mysecret`. Value type secrets are passed on the command line and translated into their
+appropriate files or environment variables.
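+
+A hypothetical configuration (secret names, paths and values below are illustrative only) could
+look like:
+
+```
+# Expose a reference type secret, served by the secret store, as a file inside the container
+spark.mesos.driver.secret.names      /mysecret
+spark.mesos.driver.secret.filenames  mysecret.txt
+
+# Alternatively, pass a value type secret and expose it as an environment variable
+# spark.mesos.driver.secret.values   topsecret
+# spark.mesos.driver.secret.envkeys  MY_SECRET
+```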
+
+## HTTP Security Headers
-### Configuring SASL Encryption
+Apache Spark can be configured to include HTTP headers to aid in preventing Cross Site Scripting
+(XSS), Cross-Frame Scripting (XFS), MIME-Sniffing, and also to enforce HTTP Strict Transport
+Security.
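+
+For instance, a deployment serving the UI over HTTPS might set the following (the values shown are
+examples taken from the table below, not recommendations):
+
+```
+spark.ui.xXssProtection               1; mode=block
+spark.ui.xContentTypeOptions.enabled  true
+spark.ui.strictTransportSecurity      max-age=31536000; includeSubDomains
+```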
-SASL encryption is currently supported for the block transfer service when authentication
-(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set
-`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration.
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.ui.xXssProtection</code></td>
+  <td><code>1; mode=block</code></td>
+  <td>
+    Value for HTTP X-XSS-Protection response header. You can choose appropriate value
+    from below:
+    <ul>
+      <li><code>0</code> (Disables XSS filtering)</li>
+      <li><code>1</code> (Enables XSS filtering. If a cross-site scripting attack is detected,
+        the browser will sanitize the page.)</li>
+      <li><code>1; mode=block</code> (Enables XSS filtering. The browser will prevent rendering
+        of the page if an attack is detected.)</li>
+    </ul>
+  </td>
+</tr>
+<tr>
+  <td><code>spark.ui.xContentTypeOptions.enabled</code></td>
+  <td><code>true</code></td>
+  <td>
+    When enabled, X-Content-Type-Options HTTP response header will be set to "nosniff".
+  </td>
+</tr>
+<tr>
+  <td><code>spark.ui.strictTransportSecurity</code></td>
+  <td>None</td>
+  <td>
+    Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate
+    value from below and set <code>expire-time</code> accordingly. This option is only used when
+    SSL/TLS is enabled.
+    <ul>
+      <li><code>max-age=&lt;expire-time&gt;</code></li>
+      <li><code>max-age=&lt;expire-time&gt;; includeSubDomains</code></li>
+      <li><code>max-age=&lt;expire-time&gt;; preload</code></li>
+    </ul>
+  </td>
+</tr>
+</table>
+
-When using an external shuffle service, it's possible to disable unencrypted connections by setting
-`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that
-option is enabled, applications that are not set up to use SASL encryption will fail to connect to
-the shuffle service.
-## Configuring Ports for Network Security
+# Configuring Ports for Network Security
Spark makes heavy use of the network, and some environments have strict requirements for using tight
firewall settings. Below are the primary ports that Spark uses for its communication and how to
configure those ports.
-### Standalone mode only
+## Standalone mode only
@@ -145,7 +586,7 @@ configure those ports.
-### All cluster managers
+## All cluster managers
@@ -186,54 +627,70 @@ configure those ports.
-### HTTP Security Headers
-Apache Spark can be configured to include HTTP Headers which aids in preventing Cross
-Site Scripting (XSS), Cross-Frame Scripting (XFS), MIME-Sniffing and also enforces HTTP
-Strict Transport Security.
+# Kerberos
+
+Spark supports submitting applications in environments that use Kerberos for authentication.
+In most cases, Spark relies on the credentials of the current logged in user when authenticating
+to Kerberos-aware services. Such credentials can be obtained by logging in to the configured KDC
+with tools like `kinit`.
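+
+For example (the principal is a placeholder):
+
+```bash
+kinit alice@EXAMPLE.COM
+```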
+
+When talking to Hadoop-based services, Spark needs to obtain delegation tokens so that non-local
+processes can authenticate. Spark ships with support for HDFS and other Hadoop file systems, Hive
+and HBase.
+
+When using a Hadoop filesystem (such as HDFS or WebHDFS), Spark will acquire the relevant tokens
+for the service hosting the user's home directory.
+
+An HBase token will be obtained if HBase is in the application's classpath, and the HBase
+configuration has Kerberos authentication turned on (`hbase.security.authentication=kerberos`).
+
+Similarly, a Hive token will be obtained if Hive is in the classpath, and the configuration includes
+URIs for remote metastore services (`hive.metastore.uris` is not empty).
+
+Delegation tokens are currently only supported in YARN and Mesos modes. Consult the
+deployment-specific page for more information.
+
+The following options provide finer-grained control for this feature:
Property Name
Default
Meaning
-
spark.ui.xXssProtection
-
1; mode=block
-
- Value for HTTP X-XSS-Protection response header. You can choose appropriate value
- from below:
-
-
0 (Disables XSS filtering)
-
1 (Enables XSS filtering. If a cross-site scripting attack is detected,
- the browser will sanitize the page.)
-
1; mode=block (Enables XSS filtering. The browser will prevent rendering
- of the page if an attack is detected.)
-
-
-
-
-
spark.ui.xContentTypeOptions.enabled
+
spark.security.credentials.${service}.enabled
true
- When value is set to "true", X-Content-Type-Options HTTP response header will be set
- to "nosniff". Set "false" to disable.
-
-
-
-
spark.ui.strictTransportSecurity
-
None
-
- Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate
- value from below and set expire-time accordingly, when Spark is SSL/TLS enabled.
-
-
max-age=<expire-time>
-
max-age=<expire-time>; includeSubDomains
-
max-age=<expire-time>; preload
-
+ Controls whether to obtain credentials for services when security is enabled.
+ By default, credentials for all supported services are retrieved when those services are
+ configured, but it's possible to disable that behavior if it somehow conflicts with the
+ application being run.
-
-See the [configuration page](configuration.html) for more details on the security configuration
-parameters, and
-org.apache.spark.SecurityManager for implementation details about security.
+## Long-Running Applications
+
+Long-running applications may run into issues if their run time exceeds the maximum delegation
+token lifetime configured in the services they need to access.
+
+Spark supports automatically creating new tokens for these applications when running in YARN mode.
+Kerberos credentials need to be provided to the Spark application via the `spark-submit` command,
+using the `--principal` and `--keytab` parameters.
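+
+For example (principal, keytab path and application jar are placeholders):
+
+```bash
+spark-submit \
+  --master yarn \
+  --deploy-mode cluster \
+  --principal alice@EXAMPLE.COM \
+  --keytab /path/to/alice.keytab \
+  --class com.example.StreamingApp app.jar
+```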
+
+The provided keytab will be copied over to the machine running the Application Master via the Hadoop
+Distributed Cache. For this reason, it's strongly recommended that both YARN and HDFS be secured
+with encryption, at least.
+
+The Kerberos login will be periodically renewed using the provided credentials, and new delegation
+tokens for supported services will be created.
+
+
+# Event Logging
+
+If your applications are using event logging, the directory where the event logs go
+(`spark.eventLog.dir`) should be manually created with proper permissions. To secure the log files,
+the directory permissions should be set to `drwxrwxrwxt`. The owner and group of the directory
+should correspond to the super user who is running the Spark History Server.
+This will allow all users to write to the directory but will prevent unprivileged users from
+reading, removing or renaming a file unless they own it. The event log files will be created by
+Spark with permissions such that only the user and group have read and write access.
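+
+As a sketch, for an event log directory on HDFS (the path and the owning user and group are
+examples), the super user could run:
+
+```bash
+hdfs dfs -mkdir /spark-events
+hdfs dfs -chown spark:spark /spark-events
+hdfs dfs -chmod 1777 /spark-events
+```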
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 8fa643abf1373..14d742de5655c 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -254,6 +254,18 @@ SPARK_WORKER_OPTS supports the following system properties:
especially if you run jobs very frequently.
+  <tr>
+    <td><code>spark.storage.cleanupFilesAfterExecutorExit</code></td>
+    <td>true</td>
+    <td>
+      Enable cleanup of non-shuffle files (such as temporary shuffle blocks, cached RDD/broadcast
+      blocks, spill files, etc.) in worker directories following executor exits. Note that this
+      doesn't overlap with `spark.worker.cleanup.enabled`, as this enables cleanup of non-shuffle
+      files in local directories of a dead executor, while `spark.worker.cleanup.enabled` enables
+      cleanup of all files/subdirectories of a stopped and timed-out application.
+      This only affects Standalone mode; support for other cluster managers can be added in the future.
+    </td>
+  </tr>
spark.worker.ui.compressedLogFileLengthCacheSize
100
@@ -338,7 +350,7 @@ worker during one single schedule iteration.
# Monitoring and Logging
-Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options.
+Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default, you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options.
In addition, detailed log output for each job is also written to the work directory of each slave node (`SPARK_HOME/work` by default). You will see two files for each job, `stdout` and `stderr`, with all output it wrote to its console.
diff --git a/docs/sparkr.md b/docs/sparkr.md
index 6685b585a393a..4faad2c4c1824 100644
--- a/docs/sparkr.md
+++ b/docs/sparkr.md
@@ -107,7 +107,7 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s
With a `SparkSession`, applications can create `SparkDataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources).
### From local data frames
-The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R.
+The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically, we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` based on the `faithful` dataset from R.
-The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example we can save the SparkDataFrame from the previous example
+The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example, we can save the SparkDataFrame from the previous example
to a Parquet file using `write.df`.
@@ -241,7 +241,7 @@ head(filter(df, df$waiting < 50))
### Grouping, Aggregation
-SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below
+SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example, we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below
{% highlight r %}
@@ -663,3 +663,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma
- The `stringsAsFactors` parameter was previously ignored with `collect`, for example, in `collect(createDataFrame(iris), stringsAsFactors = TRUE))`. It has been corrected.
- For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`.
- A warning can be raised if versions of SparkR package and the Spark JVM do not match.
+
+## Upgrading to SparkR 2.3.1 and above
+
+ - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match the behaviour of `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-based. As an example, `substr(lit('abcdef'), 2, 4)` would result in `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1.
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index a0e221b39cc34..4d8a738507bd1 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -165,7 +165,7 @@ In addition to simple column references and expressions, Datasets also have a ri
-In Python it's possible to access a DataFrame's columns either by attribute
+In Python, it's possible to access a DataFrame's columns either by attribute
(`df.age`) or by indexing (`df['age']`). While the former is convenient for
interactive data exploration, users are highly encouraged to use the
latter form, which is future proof and won't break with column names that
@@ -278,7 +278,7 @@ the bytes back into an object.
Spark SQL supports two different methods for converting existing RDDs into Datasets. The first
method uses reflection to infer the schema of an RDD that contains specific types of objects. This
-reflection based approach leads to more concise code and works well when you already know the schema
+reflection-based approach leads to more concise code and works well when you already know the schema
while writing your Spark application.
The second method for creating Datasets is through a programmatic interface that allows you to
@@ -964,7 +964,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
Sets the compression codec used when writing Parquet files. If either `compression` or
`parquet.compression` is specified in the table-specific options/properties, the precedence would be
`compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include:
- none, uncompressed, snappy, gzip, lzo.
+ none, uncompressed, snappy, gzip, lzo, brotli, lz4, zstd.
@@ -1004,6 +1004,29 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
+## ORC Files
+
+Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files.
+To do that, the following configurations are newly added. The vectorized reader is used for the
+native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl`
+is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC
+serde tables (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`),
+the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also set to `true`.
+
+
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.sql.orc.impl</code></td>
+  <td><code>native</code></td>
+  <td>
+    The name of ORC implementation. It can be one of <code>native</code> and <code>hive</code>.
+    <code>native</code> means the native ORC support that is built on Apache ORC 1.4.
+    <code>hive</code> means the ORC library in Hive 1.2.1.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.orc.enableVectorizedReader</code></td>
+  <td><code>true</code></td>
+  <td>
+    Enables vectorized orc decoding in <code>native</code> implementation. If <code>false</code>,
+    a new non-vectorized ORC reader is used in <code>native</code> implementation.
+    For <code>hive</code> implementation, this is ignored.
+  </td>
+</tr>
+</table>
+
## JSON Datasets
@@ -1191,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used
1.2.1
Version of the Hive metastore. Available
- options are 0.12.0 through 1.2.1.
+ options are 0.12.0 through 2.3.3.
@@ -1220,7 +1243,7 @@ The following options can be used to configure the version of Hive that is used
- A comma separated list of class prefixes that should be loaded using the classloader that is
+ A comma-separated list of class prefixes that should be loaded using the classloader that is
shared between Spark SQL and a specific version of Hive. An example of classes that should
be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need
to be shared are those that interact with classes that are already shared. For example,
@@ -1315,6 +1338,17 @@ the following case-insensitive options:
+
+  <tr>
+    <td><code>queryTimeout</code></td>
+    <td>
+      The number of seconds the driver will wait for a Statement object to execute. Zero means
+      there is no limit. In the write path, this option depends on how JDBC drivers implement
+      the API <code>setQueryTimeout</code>, e.g., the h2 JDBC driver checks the timeout of each
+      query instead of an entire JDBC batch. It defaults to <code>0</code>.
+    </td>
+  </tr>
+
fetchsize
@@ -1418,7 +1452,7 @@ SELECT * FROM resultTable
# Performance Tuning
-For some workloads it is possible to improve performance by either caching data in memory, or by
+For some workloads, it is possible to improve performance by either caching data in memory, or by
turning on some experimental options.
## Caching Data In Memory
@@ -1666,6 +1700,10 @@ using the call `toPandas()` and when creating a Spark DataFrame from a Pandas Da
`createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set
the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default.
+In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' could fall back automatically
+to a non-Arrow implementation if an error occurs before the actual computation within Spark.
+This can be controlled by 'spark.sql.execution.arrow.fallback.enabled'.
+
{% include_example dataframe_with_arrow python/sql/arrow.py %}
@@ -1676,7 +1714,7 @@ Using the above optimizations with Arrow will produce the same results as when A
enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the
DataFrame to the driver program and should be done on a small subset of the data. Not all Spark
data types are currently supported and an error can be raised if a column has an unsupported type,
-see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`,
+see [Supported SQL Types](#supported-sql-types). If an error occurs during `createDataFrame()`,
Spark will fall back to create the DataFrame without Arrow.
## Pandas UDFs (a.k.a. Vectorized UDFs)
@@ -1714,6 +1752,15 @@ To use `groupBy().apply()`, the user needs to define the following:
* A Python function that defines the computation for each group.
* A `StructType` object or a string that defines the schema of the output `DataFrame`.
+The output schema will be applied to the columns of the returned `pandas.DataFrame` in order by position,
+not by name. This means that the columns in the `pandas.DataFrame` must be indexed so that their
+position matches the corresponding field in the schema.
+
+Note that when creating a new `pandas.DataFrame` using a dictionary, the actual position of the column
+can differ from the order that it was placed in the dictionary. It is recommended in this case to
+explicitly define the column order using the `columns` keyword, e.g.
+`pandas.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])`, or alternatively use an `OrderedDict`.
+
Note that all data for a group will be loaded into memory before the function is applied. This can
lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for
[maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user
@@ -1734,7 +1781,7 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p
### Supported SQL Types
-Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`,
+Currently, all Spark SQL data types are supported by Arrow-based conversion except `BinaryType`, `MapType`,
`ArrayType` of `TimestampType`, and nested `StructType`.
### Setting Arrow Batch Size
@@ -1774,6 +1821,25 @@ working with timestamps in `pandas_udf`s to get the best performance, see
# Migration Guide
+## Upgrading From Spark SQL 2.3 to 2.4
+
+ - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively.
+ - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`.
+ - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe.
+ - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, a column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``.
+ - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema.
+ - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promoting both sides to TIMESTAMP. Setting `spark.sql.hive.compareDateTimestampInTimestamp` to `false` restores the previous behavior. This option will be removed in Spark 3.0.
+ - Since Spark 2.4, creating a managed table with a nonempty location is not allowed. An exception is thrown when attempting to create a managed table with a nonempty location. Setting `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` to `true` restores the previous behavior. This option will be removed in Spark 3.0.
+ - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, regardless of the order of the input arguments. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception.
+ - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these two functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains a timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behavior to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround.
+ - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files.
+ - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. Setting `spark.sql.hive.convertMetastoreOrc` to `false` restores the previous behavior.
+ - In version 2.3 and earlier, CSV rows are considered malformed if at least one column value in the row is malformed. The CSV parser dropped such rows in the DROPMALFORMED mode or output an error in the FAILFAST mode. Since Spark 2.4, a CSV row is considered malformed only when it contains malformed column values requested from the CSV datasource; other values can be ignored. As an example, a CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selecting the id column yields a row with the single column value 1234, but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`.
+
+## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above
+
+ - As of version 2.3.1 Arrow functionality, including `pandas_udf` and `toPandas()`/`createDataFrame()` with `spark.sql.execution.arrow.enabled` set to `True`, has been marked as experimental. These are still evolving and not currently recommended for use in production.
+
## Upgrading From Spark SQL 2.2 to 2.3
- Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
@@ -1929,10 +1995,13 @@ working with timestamps in `pandas_udf`s to get the best performance, see
- The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`).
- Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them.
- The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible.
+ - In PySpark, `df.replace` does not allow omitting `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone.
## Upgrading From Spark SQL 2.1 to 2.2
- - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
+ - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
+
+ - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty).
## Upgrading From Spark SQL 2.0 to 2.1
@@ -1973,7 +2042,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see
## Upgrading From Spark SQL 1.5 to 1.6
- - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC
+ - From Spark 1.6, by default, the Thrift server runs in multi-session mode. Which means each JDBC/ODBC
connection owns a copy of their own SQL configuration and temporary function registry. Cached
tables are still shared though. If you prefer to run the Thrift server in the old single-session
mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add
@@ -2121,7 +2190,7 @@ been renamed to `DataFrame`. This is primarily because DataFrames no longer inhe
directly, but instead provide most of the functionality that RDDs provide though their own
implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method.
-In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for
+In Scala, there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for
some use cases. It is still recommended that users update their code to use `DataFrame` instead.
Java and Python users will need to update their code.
@@ -2130,11 +2199,11 @@ Java and Python users will need to update their code.
Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`)
that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users
of either language should use `SQLContext` and `DataFrame`. In general these classes try to
-use types that are usable from both languages (i.e. `Array` instead of language specific collections).
+use types that are usable from both languages (i.e. `Array` instead of language-specific collections).
In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading
is used instead.
-Additionally the Java specific types API has been removed. Users of both Scala and Java should
+Additionally, the Java specific types API has been removed. Users of both Scala and Java should
use the classes present in `org.apache.spark.sql.types` to describe schema programmatically.
@@ -2191,9 +2260,9 @@ referencing a singleton.
## Compatibility with Apache Hive
Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs.
-Currently Hive SerDes and UDFs are based on Hive 1.2.1,
+Currently, Hive SerDes and UDFs are based on Hive 1.2.1,
and Spark SQL can be connected to different versions of Hive Metastore
-(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)).
+(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)).
#### Deploying in Existing Hive Warehouses
@@ -2283,10 +2352,10 @@ A handful of Hive optimizations are not yet included in Spark. Some of these (su
less important due to Spark SQL's in-memory computational model. Others are slotted for future
releases of Spark SQL.
-* Block level bitmap indexes and virtual columns (used to build indexes)
-* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you
+* Block-level bitmap indexes and virtual columns (used to build indexes)
+* Automatically determine the number of reducers for joins and groupbys: Currently, in Spark SQL, you
need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`".
-* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still
+* Meta-data only query: For queries that can be answered by using only metadata, Spark SQL still
launches tasks to compute the result.
* Skew data flag: Spark SQL does not follow the skew data flags in Hive.
* `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint.
@@ -2943,6 +3012,6 @@ does not exactly match standard floating point semantics.
Specifically:
- NaN = NaN returns true.
- - In aggregations all NaN values are grouped together.
+ - In aggregations, all NaN values are grouped together.
- NaN is treated as a normal value in join keys.
- NaN values go last when in ascending order, larger than any other numeric value.
diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md
index 1dd54719b21aa..dacaa3438d489 100644
--- a/docs/storage-openstack-swift.md
+++ b/docs/storage-openstack-swift.md
@@ -39,7 +39,7 @@ For example, for Maven support, add the following to the pom.xml fi
# Configuration Parameters
Create core-site.xml and place it inside Spark's conf directory.
-The main category of parameters that should be configured are the authentication parameters
+The main category of parameters that should be configured is the authentication parameters
required by Keystone.
The following table contains a list of Keystone mandatory parameters. PROVIDER can be
diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md
index 257a4f7d4f3ca..a1b6942ffe0a4 100644
--- a/docs/streaming-flume-integration.md
+++ b/docs/streaming-flume-integration.md
@@ -17,7 +17,7 @@ Choose a machine in your cluster such that
- Flume can be configured to push data to a port on that machine.
-Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able push data.
+Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able to push data.
#### Configuring Flume
Configure Flume agent to send data to an Avro sink by having the following in the configuration file.
@@ -100,7 +100,7 @@ Choose a machine that will run the custom sink in a Flume agent. The rest of the
#### Configuring Flume
Configuring Flume on the chosen machine requires the following two steps.
-1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink .
+1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink.
(i) *Custom sink JAR*: Download the JAR corresponding to the following artifact (or [direct link](http://search.maven.org/remotecontent?filepath=org/apache/spark/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}/{{site.SPARK_VERSION_SHORT}}/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}-{{site.SPARK_VERSION_SHORT}}.jar)).
@@ -128,7 +128,7 @@ Configuring Flume on the chosen machine requires the following two steps.
agent.sinks.spark.port =
agent.sinks.spark.channel = memoryChannel
- Also make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink.
+ Also, make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink.
See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about
configuring Flume agents.
diff --git a/docs/streaming-kafka-0-8-integration.md b/docs/streaming-kafka-0-8-integration.md
index 9f0671da2ee31..becf217738d26 100644
--- a/docs/streaming-kafka-0-8-integration.md
+++ b/docs/streaming-kafka-0-8-integration.md
@@ -10,7 +10,7 @@ Here we explain how to configure Spark Streaming to receive data from Kafka. The
## Approach 1: Receiver-based Approach
This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data.
-However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs.
+However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability)). To ensure zero-data loss, you have to additionally enable Write-Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write-ahead logs on a distributed file system (e.g. HDFS), so that all the data can be recovered on failure. See the [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write-Ahead Logs.
Next, we discuss how to use this approach in your streaming application.
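For a rough illustration of this approach, here is a minimal Scala sketch; the ZooKeeper quorum, group id and topic map are placeholders, and `ssc` is assumed to be an existing `StreamingContext`.

{% highlight scala %}
import org.apache.spark.streaming.kafka.KafkaUtils

// Receiver-based stream: each record is a (key, message) pair
val kafkaStream = KafkaUtils.createStream(
  ssc,                   // existing StreamingContext (assumed)
  "zkhost1:2181",        // ZooKeeper quorum (placeholder)
  "my-consumer-group",   // consumer group id (placeholder)
  Map("my-topic" -> 1))  // topics and number of receiver threads per topic
{% endhighlight %}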
@@ -55,11 +55,11 @@ Next, we discuss how to use this approach in your streaming application.
**Points to remember:**
- - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that.
+ - Topic partitions in Kafka do not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that.
- Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers.
- - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use
+ - If you have enabled Write-Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, set the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use
`KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`).
3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications.
@@ -80,9 +80,9 @@ This approach has the following advantages over the receiver-based approach (i.e
- *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune.
-- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka.
+- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write-Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write-Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write-Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka.
-- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information).
+- *Exactly-once semantics:* The first approach uses Kafka's high-level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write-ahead logs) can ensure zero data loss (i.e. at-least-once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use the simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for the output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information).
Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below).
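For illustration, a rough Scala sketch of the direct approach and of reading the per-batch offset ranges; the broker list and topic name are placeholders, and `ssc` is assumed to be an existing `StreamingContext`.

{% highlight scala %}
import kafka.serializer.StringDecoder
import org.apache.spark.streaming.kafka.{HasOffsetRanges, KafkaUtils}

val kafkaParams = Map("metadata.broker.list" -> "broker1:9092,broker2:9092")
val topics = Set("my-topic")

// Direct stream: no receiver; offsets are tracked in Spark Streaming checkpoints
val directStream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
  ssc, kafkaParams, topics)

directStream.foreachRDD { rdd =>
  // Offset ranges for this batch, e.g. to update ZooKeeper or an external store yourself
  val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
  offsetRanges.foreach(o => println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}"))
}
{% endhighlight %}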
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index ffda36d64a770..c30959263cdfa 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -1461,7 +1461,7 @@ Note that the connections in the pool should be lazily created on demand and tim
***
## DataFrame and SQL Operations
-You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL.
+You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore, this has to be done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL.
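A minimal Scala sketch of that pattern follows; the object name and the `DStream[String]` called `words` are illustrative.

{% highlight scala %}
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/** Lazily instantiated singleton SparkSession, re-creatable after driver recovery. */
object SparkSessionSingleton {
  @transient private var instance: SparkSession = _

  def getInstance(sparkConf: SparkConf): SparkSession = {
    if (instance == null) {
      instance = SparkSession.builder.config(sparkConf).getOrCreate()
    }
    instance
  }
}

// Inside a DStream output operation (assuming a DStream[String] named `words`)
words.foreachRDD { rdd =>
  val spark = SparkSessionSingleton.getInstance(rdd.sparkContext.getConf)
  import spark.implicits._

  val wordsDF = rdd.toDF("word")
  wordsDF.createOrReplaceTempView("words")
  spark.sql("select word, count(*) as total from words group by word").show()
}
{% endhighlight %}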
@@ -2010,10 +2010,10 @@ To run a Spark Streaming applications, you need to have the following.
+ *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this
with Mesos.
-- *Configuring write ahead logs* - Since Spark 1.2,
- we have introduced _write ahead logs_ for achieving strong
+- *Configuring write-ahead logs* - Since Spark 1.2,
+ we have introduced _write-ahead logs_ for achieving strong
fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into
- a write ahead log in the configuration checkpoint directory. This prevents data loss on driver
+ a write-ahead log in the configuration checkpoint directory. This prevents data loss on driver
recovery, thus ensuring zero data loss (discussed in detail in the
[Fault-tolerance Semantics](#fault-tolerance-semantics) section). This can be enabled by setting
the [configuration parameter](configuration.html#spark-streaming)
@@ -2021,15 +2021,15 @@ To run a Spark Streaming applications, you need to have the following.
come at the cost of the receiving throughput of individual receivers. This can be corrected by
running [more receivers in parallel](#level-of-parallelism-in-data-receiving)
to increase aggregate throughput. Additionally, it is recommended that the replication of the
- received data within Spark be disabled when the write ahead log is enabled as the log is already
+ received data within Spark be disabled when the write-ahead log is enabled as the log is already
stored in a replicated storage system. This can be done by setting the storage level for the
input stream to `StorageLevel.MEMORY_AND_DISK_SER`. While using S3 (or any file system that
- does not support flushing) for _write ahead logs_, please remember to enable
+ does not support flushing) for _write-ahead logs_, please remember to enable
`spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and
`spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See
[Spark Streaming Configuration](configuration.html#spark-streaming) for more details.
- Note that Spark will not encrypt data written to the write ahead log when I/O encryption is
- enabled. If encryption of the write ahead log data is desired, it should be stored in a file
+ Note that Spark will not encrypt data written to the write-ahead log when I/O encryption is
+ enabled. If encryption of the write-ahead log data is desired, it should be stored in a file
system that supports encryption natively.
- *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming
@@ -2284,9 +2284,9 @@ Having bigger blockinterval means bigger blocks. A high value of `spark.locality
- Instead of relying on batchInterval and blockInterval, you can define the number of partitions by calling `inputDstream.repartition(n)`. This reshuffles the data in RDD randomly to create n number of partitions. Yes, for greater parallelism. Though comes at the cost of a shuffle. An RDD's processing is scheduled by driver's jobscheduler as a job. At a given point of time only one job is active. So, if one job is executing the other jobs are queued.
-- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the another. To avoid this, you can union two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However the partitioning of the RDDs is not impacted.
+- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the other. To avoid this, you can union the two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However, the partitioning of the RDDs is not impacted.
-- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up in throwing exceptions (most probably BlockNotFoundException). Currently there is no way to pause the receiver. Using SparkConf configuration `spark.streaming.receiver.maxRate`, rate of receiver can be limited.
+- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up throwing exceptions (most probably BlockNotFoundException). Currently, there is no way to pause the receiver. Using the SparkConf configuration `spark.streaming.receiver.maxRate`, the rate of the receiver can be limited.
***************************************************************************************************
@@ -2388,7 +2388,7 @@ then besides these losses, all of the past data that was received and replicated
lost. This will affect the results of the stateful transformations.
To avoid this loss of past received data, Spark 1.2 introduced _write
-ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs
+ahead logs_ which save the received data to fault-tolerant storage. With the [write-ahead logs
enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least once guarantee.
The following table summarizes the semantics under failures:
@@ -2402,7 +2402,7 @@ The following table summarizes the semantics under failures:
Spark 1.1 or earlier, OR
- Spark 1.2 or later without write ahead logs
+ Spark 1.2 or later without write-ahead logs
Buffered data lost with unreliable receivers
@@ -2416,7 +2416,7 @@ The following table summarizes the semantics under failures:
-
Spark 1.2 or later with write ahead logs
+
Spark 1.2 or later with write-ahead logs
Zero data loss with reliable receivers
At-least once semantics
diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md
index 5647ec6bc5797..71fd5b10cc407 100644
--- a/docs/structured-streaming-kafka-integration.md
+++ b/docs/structured-streaming-kafka-integration.md
@@ -15,7 +15,7 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli
For Python applications, you need to add this above library and its dependencies when deploying your
application. See the [Deploying](#deploying) subsection below.
-For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also see the [Deploying](#deploying) subsection below.
+For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also, see the [Deploying](#deploying) subsection below.
## Reading Data from Kafka
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index 48d6d0b542cc0..0842e8dd88672 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -8,7 +8,7 @@ title: Structured Streaming Programming Guide
{:toc}
# Overview
-Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.*
+Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write-Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.*
Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements.
@@ -479,7 +479,7 @@ detail in the [Window Operations](#window-operations-on-event-time) section.
## Fault Tolerance Semantics
Delivering end-to-end exactly-once semantics was one of key goals behind the design of Structured Streaming. To achieve that, we have designed the Structured Streaming sources, the sinks and the execution engine to reliably track the exact progress of the processing so that it can handle any kind of failure by restarting and/or reprocessing. Every streaming source is assumed to have offsets (similar to Kafka offsets, or Kinesis sequence numbers)
-to track the read position in the stream. The engine uses checkpointing and write ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure.
+to track the read position in the stream. The engine uses checkpointing and write-ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure.
# API using Datasets and DataFrames
Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession`
@@ -690,7 +690,7 @@ These examples generate streaming DataFrames that are untyped, meaning that the
By default, Structured Streaming from file based sources requires you to specify the schema, rather than rely on Spark to infer it automatically. This restriction ensures a consistent schema will be used for the streaming query, even in the case of failures. For ad-hoc use cases, you can reenable schema inference by setting `spark.sql.streaming.schemaInference` to `true`.
-Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`).
+Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user-provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`).
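For illustration, a small Scala sketch of a file source with a user-provided schema; the column names and path are placeholders, and `spark` is assumed to be an existing `SparkSession`.

{% highlight scala %}
import org.apache.spark.sql.types.StructType

// File sources require an explicit schema unless spark.sql.streaming.schemaInference is enabled
val userSchema = new StructType()
  .add("name", "string")
  .add("age", "integer")

val csvDF = spark.readStream
  .schema(userSchema)
  .csv("/path/to/data")  // may contain partition directories such as /path/to/data/year=2016/
{% endhighlight %}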
## Operations on streaming DataFrames/Datasets
You can apply all kinds of operations on streaming DataFrames/Datasets – ranging from untyped, SQL-like operations (e.g. `select`, `where`, `groupBy`), to typed RDD-like operations (e.g. `map`, `filter`, `flatMap`). See the [SQL programming guide](sql-programming-guide.html) for more details. Let’s take a look at a few example operations that you can use.
@@ -904,7 +904,7 @@ windowedCounts <- count(
-### Handling Late Data and Watermarking
+#### Handling Late Data and Watermarking
Now consider what happens if one of the events arrives late to the application.
For example, say, a word generated at 12:04 (i.e. event time) could be received by
the application at 12:11. The application should use the time 12:04 instead of 12:11
@@ -925,7 +925,9 @@ specifying the event time column and the threshold on how late the data is expec
event time. For a specific window starting at time `T`, the engine will maintain state and allow late
data to update the state until `(max event time seen by the engine - late threshold > T)`.
In other words, late data within the threshold will be aggregated,
-but data later than the threshold will be dropped. Let's understand this with an example. We can
+but data later than the threshold will start getting dropped
+(see [later](#semantic-guarantees-of-aggregation-with-watermarking)
+in the section for the exact guarantees). Let's understand this with an example. We can
easily define watermarking on the previous example using `withWatermark()` as shown below.
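For illustration, a rough Scala sketch of such a watermark, assuming a streaming Dataset named `words` with an event-time column `timestamp` and a `word` column.

{% highlight scala %}
import org.apache.spark.sql.functions.{col, window}

val windowedCounts = words
  .withWatermark("timestamp", "10 minutes")  // tolerate data up to 10 minutes late
  .groupBy(
    window(col("timestamp"), "10 minutes", "5 minutes"),
    col("word"))
  .count()
{% endhighlight %}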
@@ -1031,7 +1033,9 @@ then drops intermediate state of a window < watermark, and appends the final
counts to the Result Table/sink. For example, the final counts of window `12:00 - 12:10` is
appended to the Result Table only after the watermark is updated to `12:11`.
-**Conditions for watermarking to clean aggregation state**
+##### Conditions for watermarking to clean aggregation state
+{:.no_toc}
+
It is important to note that the following conditions must be satisfied for the watermarking to
clean the state in aggregation queries *(as of Spark 2.1.1, subject to change in the future)*.
@@ -1051,6 +1055,16 @@ from the aggregation column.
For example, `df.groupBy("time").count().withWatermark("time", "1 min")` is invalid in Append
output mode.
+##### Semantic Guarantees of Aggregation with Watermarking
+{:.no_toc}
+
+- A watermark delay (set with `withWatermark`) of "2 hours" guarantees that the engine will never
+drop any data that is less than 2 hours delayed. In other words, any data less than 2 hours behind
+(in terms of event-time) the latest data processed till then is guaranteed to be aggregated.
+
+- However, the guarantee is strict only in one direction. Data delayed by more than 2 hours is
+not guaranteed to be dropped; it may or may not get aggregated. The more delayed the data is, the less
+likely the engine is to process it.
### Join Operations
Structured Streaming supports joining a streaming Dataset/DataFrame with a static Dataset/DataFrame
@@ -1062,7 +1076,7 @@ Dataset/DataFrame will be the exactly the same as if it was with a static Datase
containing the same data in the stream.
-#### Stream-static joins
+#### Stream-static Joins
Since the introduction in Spark 2.0, Structured Streaming has supported joins (inner join and some
type of outer joins) between a streaming and a static DataFrame/Dataset. Here is a simple example.
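A minimal Scala sketch, assuming `staticDf` was created with `spark.read`, `streamingDf` with `spark.readStream`, and both share a `customer_id` column.

{% highlight scala %}
// Inner equi-join between the stream and the static Dataset
val innerJoined = streamingDf.join(staticDf, "customer_id")

// Left outer join, keeping streaming rows with no match on the static side
val leftJoined = streamingDf.join(staticDf, Seq("customer_id"), "left_outer")
{% endhighlight %}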
@@ -1269,6 +1283,12 @@ joined <- join(
+###### Semantic Guarantees of Stream-stream Inner Joins with Watermarking
+{:.no_toc}
+This is similar to the [guarantees provided by watermarking on aggregations](#semantic-guarantees-of-aggregation-with-watermarking).
+A watermark delay of "2 hours" guarantees that the engine will never drop any data that is less than
+ 2 hours delayed. But data delayed by more than 2 hours may or may not get processed.
+
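For illustration, a rough Scala sketch of such a join; the `impressions` and `clicks` streams and their column names are assumptions.

{% highlight scala %}
import org.apache.spark.sql.functions.expr

// Define watermarks on both sides so that old state can eventually be dropped
val impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours")
val clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours")

// Join with an event-time range condition in addition to the key equality
val joined = impressionsWithWatermark.join(
  clicksWithWatermark,
  expr("""
    clickAdId = impressionAdId AND
    clickTime >= impressionTime AND
    clickTime <= impressionTime + interval 1 hour
  """))
{% endhighlight %}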
##### Outer Joins with Watermarking
While the watermark + event-time constraints is optional for inner joins, for left and right outer
joins they must be specified. This is because for generating the NULL results in outer join, the
@@ -1347,7 +1367,14 @@ joined <- join(
-There are a few points to note regarding outer joins.
+###### Semantic Guarantees of Stream-stream Outer Joins with Watermarking
+{:.no_toc}
+Outer joins have the same guarantees as [inner joins](#semantic-guarantees-of-stream-stream-inner-joins-with-watermarking)
+regarding watermark delays and whether data will be dropped or not.
+
+###### Caveats
+{:.no_toc}
+There are a few important characteristics to note regarding how the outer results are generated.
- *The outer NULL results will be generated with a delay that depends on the specified watermark
delay and the time range condition.* This is because the engine has to wait for that long to ensure
@@ -1962,7 +1989,7 @@ head(sql("select * from aggregates"))
-#### Using Foreach
+##### Using Foreach
The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.1, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter`
([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs),
which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points.
@@ -1979,6 +2006,172 @@ which has methods that get called whenever there is a sequence of rows generated
- Whenever `open` is called, `close` will also be called (unless the JVM exits due to some error). This is true even if `open` returns false. If there is any error in processing and writing the data, `close` will be called with the error. It is your responsibility to clean up state (e.g. connections, transactions, etc.) that have been created in `open` such that there are no resource leaks.
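For illustration, a minimal Scala sketch of a `ForeachWriter`; `streamingDF` is a placeholder streaming DataFrame and the method bodies only indicate where connection handling would go.

{% highlight scala %}
import org.apache.spark.sql.{ForeachWriter, Row}

val query = streamingDF.writeStream
  .foreach(new ForeachWriter[Row] {
    def open(partitionId: Long, version: Long): Boolean = {
      // e.g. open a connection here; returning false skips these rows
      true
    }
    def process(record: Row): Unit = {
      // e.g. write the row using the opened connection
      println(record)
    }
    def close(errorOrNull: Throwable): Unit = {
      // e.g. close the connection; errorOrNull is non-null if processing failed
    }
  })
  .start()
{% endhighlight %}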
+#### Triggers
+The trigger settings of a streaming query define the timing of streaming data processing, whether
+the query is going to be executed as a micro-batch query with a fixed batch interval or as a continuous processing query.
+Here are the different kinds of triggers that are supported.
+
+<table class="table">
+  <tr>
+    <th>Trigger Type</th>
+    <th>Description</th>
+  </tr>
+  <tr>
+    <td><b>unspecified (default)</b></td>
+    <td>
+        If no trigger setting is explicitly specified, then by default, the query will be
+        executed in micro-batch mode, where micro-batches will be generated as soon as
+        the previous micro-batch has completed processing.
+    </td>
+  </tr>
+  <tr>
+    <td><b>Fixed interval micro-batches</b></td>
+    <td>
+        The query will be executed in micro-batch mode, where micro-batches will be kicked off
+        at the user-specified intervals.
+        <ul>
+          <li>If the previous micro-batch completes within the interval, then the engine will wait until
+          the interval is over before kicking off the next micro-batch.</li>
+          <li>If the previous micro-batch takes longer than the interval to complete (i.e. if an
+          interval boundary is missed), then the next micro-batch will start as soon as the
+          previous one completes (i.e., it will not wait for the next interval boundary).</li>
+          <li>If no new data is available, then no micro-batch will be kicked off.</li>
+        </ul>
+    </td>
+  </tr>
+  <tr>
+    <td><b>One-time micro-batch</b></td>
+    <td>
+        The query will execute <b>only one</b> micro-batch to process all the available data and then
+        stop on its own. This is useful in scenarios where you want to periodically spin up a cluster,
+        process everything that is available since the last period, and then shut down the
+        cluster. In some cases, this may lead to significant cost savings.
+    </td>
+  </tr>
+  <tr>
+    <td><b>Continuous with fixed checkpoint interval</b><br/><i>(experimental)</i></td>
+    <td>
+        The query will be executed in the new low-latency, continuous processing mode. Read more
+        about this in the Continuous Processing section below.
+    </td>
+  </tr>
+</table>
+
+{% highlight r %}
+# Default trigger (runs micro-batch as soon as it can)
+write.stream(df, "console")
+
+# ProcessingTime trigger with two-seconds micro-batch interval
+write.stream(df, "console", trigger.processingTime = "2 seconds")
+
+# One-time trigger
+write.stream(df, "console", trigger.once = TRUE)
+
+# Continuous trigger is not yet supported
+{% endhighlight %}
+
+
+
+
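For illustration, the same trigger settings expressed as a Scala sketch, assuming a streaming DataFrame named `df`.

{% highlight scala %}
import org.apache.spark.sql.streaming.Trigger

// Default trigger (runs micro-batch as soon as it can)
df.writeStream.format("console").start()

// ProcessingTime trigger with two-seconds micro-batch interval
df.writeStream.format("console").trigger(Trigger.ProcessingTime("2 seconds")).start()

// One-time trigger
df.writeStream.format("console").trigger(Trigger.Once()).start()

// Continuous trigger with one-second checkpointing interval
df.writeStream.format("console").trigger(Trigger.Continuous("1 second")).start()
{% endhighlight %}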
## Managing Streaming Queries
The `StreamingQuery` object created when a query is started can be used to monitor and manage the query.
@@ -2468,7 +2661,7 @@ sql("SET spark.sql.streaming.metricsEnabled=true")
All queries started in the SparkSession after this configuration has been enabled will report metrics through Dropwizard to whatever [sinks](monitoring.html#metrics) have been configured (e.g. Ganglia, Graphite, JMX, etc.).
## Recovering from Failures with Checkpointing
-In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
+In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write-ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
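A minimal Scala sketch of starting a query with a checkpoint location; the query name, path and the `aggDF` DataFrame are placeholders.

{% highlight scala %}
val query = aggDF.writeStream
  .queryName("aggregates")
  .outputMode("complete")
  .option("checkpointLocation", "hdfs:///path/to/checkpoint/dir")  // HDFS-compatible path
  .format("memory")
  .start()
{% endhighlight %}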
-# Continuous Processing [Experimental]
+# Continuous Processing
+## [Experimental]
+{:.no_toc}
+
**Continuous processing** is a new, experimental streaming execution mode introduced in Spark 2.3 that enables low (~1 ms) end-to-end latency with at-least-once fault-tolerance guarantees. Compare this with the default *micro-batch processing* engine which can achieve exactly-once guarantees but achieve latencies of ~100ms at best. For some types of queries (discussed below), you can choose which mode to execute them in without modifying the application logic (i.e. without changing the DataFrame/Dataset operations).
To run a supported query in continuous processing mode, all you need to do is specify a **continuous trigger** with the desired checkpoint interval as a parameter. For example,
@@ -2589,6 +2785,8 @@ spark \
A checkpoint interval of 1 second means that the continuous processing engine will record the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees.
## Supported Queries
+{:.no_toc}
+
As of Spark 2.3, only the following type of queries are supported in the continuous processing mode.
- *Operations*: Only map-like Dataset/DataFrame operations are supported in continuous mode, that is, only projections (`select`, `map`, `flatMap`, `mapPartitions`, etc.) and selections (`where`, `filter`, etc.).
@@ -2606,6 +2804,8 @@ As of Spark 2.3, only the following type of queries are supported in the continu
See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections for more details on them. While the console sink is good for testing, the end-to-end low-latency processing can be best observed with Kafka as the source and sink, as this allows the engine to process the data and make the results available in the output topic within milliseconds of the input data being available in the input topic.
## Caveats
+{:.no_toc}
+
- Continuous processing engine launches multiple long-running tasks that continuously read data from sources, process it and continuously write to sinks. The number of tasks required by the query depends on how many partitions the query can read from the sources in parallel. Therefore, before starting a continuous processing query, you must ensure there are enough cores in the cluster to run all the tasks in parallel. For example, if you are reading from a Kafka topic that has 10 partitions, then the cluster must have at least 10 cores for the query to make progress.
- Stopping a continuous processing stream may produce spurious task termination warnings. These can be safely ignored.
- There are currently no automatic retries of failed tasks. Any failure will lead to the query being stopped and it needs to be manually restarted from the checkpoint.
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index a3643bf0838a1..77aa083c4a584 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -177,7 +177,7 @@ The master URL passed to Spark can be in one of the following formats:
# Loading Configuration from a File
The `spark-submit` script can load default [Spark configuration values](configuration.html) from a
-properties file and pass them on to your application. By default it will read options
+properties file and pass them on to your application. By default, it will read options
from `conf/spark-defaults.conf` in the Spark directory. For more detail, see the section on
[loading default configurations](configuration.html#loading-default-configurations).
diff --git a/docs/tuning.md b/docs/tuning.md
index fc27713f28d46..1c3bd0e8758ff 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -132,7 +132,7 @@ The best way to size the amount of memory consumption a dataset will require is
into cache, and look at the "Storage" page in the web UI. The page will tell you how much memory the RDD
is occupying.
-To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method
+To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method.
This is useful for experimenting with different data layouts to trim memory usage, as well as
determining the amount of space a broadcast variable will occupy on each executor heap.
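For illustration, a minimal Scala sketch; `myData` is a placeholder for any JVM object you plan to cache or broadcast.

{% highlight scala %}
import org.apache.spark.util.SizeEstimator

val estimatedBytes: Long = SizeEstimator.estimate(myData)
println(s"Approximate in-memory size: $estimatedBytes bytes")
{% endhighlight %}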
@@ -196,7 +196,7 @@ To further tune garbage collection, we first need to understand some basic infor
* A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects
that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old
- enough or Survivor2 is full, it is moved to Old. Finally when Old is close to full, a full GC is invoked.
+ enough or Survivor2 is full, it is moved to Old. Finally, when Old is close to full, a full GC is invoked.
The goal of GC tuning in Spark is to ensure that only long-lived RDDs are stored in the Old generation and that
the Young generation is sufficiently sized to store short-lived objects. This will help avoid full GCs to collect
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java
new file mode 100644
index 0000000000000..51865637df6f6
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+// $example on$
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.spark.ml.clustering.PowerIterationClustering;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+// $example off$
+
+public class JavaPowerIterationClusteringExample {
+ public static void main(String[] args) {
+ // Create a SparkSession.
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaPowerIterationClustering")
+ .getOrCreate();
+
+ // $example on$
+    List<Row> data = Arrays.asList(
+ RowFactory.create(0L, 1L, 1.0),
+ RowFactory.create(0L, 2L, 1.0),
+ RowFactory.create(1L, 2L, 1.0),
+ RowFactory.create(3L, 4L, 1.0),
+ RowFactory.create(4L, 0L, 0.1)
+ );
+
+ StructType schema = new StructType(new StructField[]{
+ new StructField("src", DataTypes.LongType, false, Metadata.empty()),
+ new StructField("dst", DataTypes.LongType, false, Metadata.empty()),
+ new StructField("weight", DataTypes.DoubleType, false, Metadata.empty())
+ });
+
+    Dataset<Row> df = spark.createDataFrame(data, schema);
+
+ PowerIterationClustering model = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIter(10)
+ .setInitMode("degree")
+ .setWeightCol("weight");
+
+    Dataset<Row> result = model.assignClusters(df);
+ result.show(false);
+ // $example off$
+ spark.stop();
+ }
+}
diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py
index 6286ba6541fbd..a18722c687f8b 100644
--- a/examples/src/main/python/avro_inputformat.py
+++ b/examples/src/main/python/avro_inputformat.py
@@ -61,7 +61,7 @@
Assumes you have Avro data stored in . Reader schema can be optionally specified
in [reader_schema_file].
""", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
path = sys.argv[1]
diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py
index 92e0a3ae2ee60..a42d711fc505f 100755
--- a/examples/src/main/python/kmeans.py
+++ b/examples/src/main/python/kmeans.py
@@ -49,7 +49,7 @@ def closestPoint(p, centers):
if len(sys.argv) != 4:
print("Usage: kmeans ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
print("""WARN: This is a naive implementation of KMeans Clustering and is given
as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an
diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py
index 01c938454b108..bcc4e0f4e8eae 100755
--- a/examples/src/main/python/logistic_regression.py
+++ b/examples/src/main/python/logistic_regression.py
@@ -48,7 +48,7 @@ def readPointBatch(iterator):
if len(sys.argv) != 3:
print("Usage: logistic_regression ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
print("""WARN: This is a naive implementation of Logistic Regression and is
given as an example!
diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py
index 109f901012c9c..cabc3de68f2f4 100644
--- a/examples/src/main/python/ml/dataframe_example.py
+++ b/examples/src/main/python/ml/dataframe_example.py
@@ -17,7 +17,7 @@
"""
An example of how to use DataFrame for ML. Run with::
- bin/spark-submit examples/src/main/python/ml/dataframe_example.py
+ bin/spark-submit examples/src/main/python/ml/dataframe_example.py
"""
from __future__ import print_function
@@ -33,20 +33,20 @@
if __name__ == "__main__":
if len(sys.argv) > 2:
print("Usage: dataframe_example.py ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
elif len(sys.argv) == 2:
- input = sys.argv[1]
+ input_path = sys.argv[1]
else:
- input = "data/mllib/sample_libsvm_data.txt"
+ input_path = "data/mllib/sample_libsvm_data.txt"
spark = SparkSession \
.builder \
.appName("DataFrameExample") \
.getOrCreate()
- # Load input data
- print("Loading LIBSVM file with UDT from " + input + ".")
- df = spark.read.format("libsvm").load(input).cache()
+ # Load an input file
+ print("Loading LIBSVM file with UDT from " + input_path + ".")
+ df = spark.read.format("libsvm").load(input_path).cache()
print("Schema from LIBSVM:")
df.printSchema()
print("Loaded training data as a DataFrame with " +
diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py
index 0e13546b88e67..089504fa7064b 100755
--- a/examples/src/main/python/mllib/correlations.py
+++ b/examples/src/main/python/mllib/correlations.py
@@ -31,7 +31,7 @@
if __name__ == "__main__":
if len(sys.argv) not in [1, 2]:
print("Usage: correlations ()", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonCorrelations")
if len(sys.argv) == 2:
filepath = sys.argv[1]
diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py
index 002fc75799648..1bdb3e9b4a2af 100755
--- a/examples/src/main/python/mllib/kmeans.py
+++ b/examples/src/main/python/mllib/kmeans.py
@@ -36,7 +36,7 @@ def parseVector(line):
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: kmeans ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="KMeans")
lines = sc.textFile(sys.argv[1])
data = lines.map(parseVector)
diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py
index d4f1d34e2d8cf..87efe17375226 100755
--- a/examples/src/main/python/mllib/logistic_regression.py
+++ b/examples/src/main/python/mllib/logistic_regression.py
@@ -42,7 +42,7 @@ def parsePoint(line):
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: logistic_regression ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonLR")
points = sc.textFile(sys.argv[1]).map(parsePoint)
iterations = int(sys.argv[2])
diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py
index 729bae30b152c..9a429b5f8abdf 100755
--- a/examples/src/main/python/mllib/random_rdd_generation.py
+++ b/examples/src/main/python/mllib/random_rdd_generation.py
@@ -29,7 +29,7 @@
if __name__ == "__main__":
if len(sys.argv) not in [1, 2]:
print("Usage: random_rdd_generation", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonRandomRDDGeneration")
diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py
index b7033ab7daeb3..00e7cf4bbcdbf 100755
--- a/examples/src/main/python/mllib/sampled_rdds.py
+++ b/examples/src/main/python/mllib/sampled_rdds.py
@@ -29,7 +29,7 @@
if __name__ == "__main__":
if len(sys.argv) not in [1, 2]:
print("Usage: sampled_rdds ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
if len(sys.argv) == 2:
datapath = sys.argv[1]
else:
@@ -43,7 +43,7 @@
numExamples = examples.count()
if numExamples == 0:
print("Error: Data file had no samples to load.", file=sys.stderr)
- exit(1)
+ sys.exit(1)
print('Loaded data with %d examples from file: %s' % (numExamples, datapath))
# Example: RDD.sample() and RDD.takeSample()
diff --git a/examples/src/main/python/mllib/streaming_linear_regression_example.py b/examples/src/main/python/mllib/streaming_linear_regression_example.py
index f600496867c11..714c9a0de7217 100644
--- a/examples/src/main/python/mllib/streaming_linear_regression_example.py
+++ b/examples/src/main/python/mllib/streaming_linear_regression_example.py
@@ -36,7 +36,7 @@
if len(sys.argv) != 3:
print("Usage: streaming_linear_regression_example.py ",
file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample")
ssc = StreamingContext(sc, 1)
diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py
index 0d6c253d397a0..2c19e8700ab16 100755
--- a/examples/src/main/python/pagerank.py
+++ b/examples/src/main/python/pagerank.py
@@ -47,7 +47,7 @@ def parseNeighbors(urls):
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: pagerank ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
print("WARN: This is a naive implementation of PageRank and is given as an example!\n" +
"Please refer to PageRank implementation provided by graphx",
diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py
index a3f86cf8999cf..83041f0040a0c 100644
--- a/examples/src/main/python/parquet_inputformat.py
+++ b/examples/src/main/python/parquet_inputformat.py
@@ -45,7 +45,7 @@
/path/to/examples/parquet_inputformat.py
Assumes you have Parquet data stored in .
""", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
path = sys.argv[1]
diff --git a/examples/src/main/python/py_container_checks.py b/examples/src/main/python/py_container_checks.py
new file mode 100644
index 0000000000000..f6b3be2806c82
--- /dev/null
+++ b/examples/src/main/python/py_container_checks.py
@@ -0,0 +1,32 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+
+
+def version_check(python_env, major_python_version):
+ """
+ These are various tests to test the Python container image.
+ This file will be distributed via --py-files in the e2e tests.
+ """
+ env_version = os.environ.get('PYSPARK_PYTHON')
+ print("Python runtime version check is: " +
+ str(sys.version_info[0] == major_python_version))
+
+ print("Python environment version check is: " +
+ str(env_version == python_env))
diff --git a/examples/src/main/python/pyfiles.py b/examples/src/main/python/pyfiles.py
new file mode 100644
index 0000000000000..4193654b49a12
--- /dev/null
+++ b/examples/src/main/python/pyfiles.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+
+from pyspark.sql import SparkSession
+
+
+if __name__ == "__main__":
+ """
+ Usage: pyfiles [major_python_version]
+ """
+ spark = SparkSession \
+ .builder \
+ .appName("PyFilesTest") \
+ .getOrCreate()
+
+ from py_container_checks import version_check
+    # Begin Python container checks
+ version_check(sys.argv[1], 2 if sys.argv[1] == "python" else 3)
+
+ spark.stop()
diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py
index 81898cf6d5ce6..d3cd985d197e3 100755
--- a/examples/src/main/python/sort.py
+++ b/examples/src/main/python/sort.py
@@ -25,7 +25,7 @@
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: sort ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
spark = SparkSession\
.builder\
diff --git a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py
index 9e8a552b3b10b..921067891352a 100644
--- a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py
+++ b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py
@@ -49,7 +49,7 @@
print("""
Usage: structured_kafka_wordcount.py
""", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
bootstrapServers = sys.argv[1]
subscribeType = sys.argv[2]
diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py
index c3284c1d01017..9ac392164735b 100644
--- a/examples/src/main/python/sql/streaming/structured_network_wordcount.py
+++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py
@@ -38,7 +38,7 @@
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: structured_network_wordcount.py ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
host = sys.argv[1]
port = int(sys.argv[2])
diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py
index db672551504b5..c4e3bbf44cd5a 100644
--- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py
+++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py
@@ -53,7 +53,7 @@
msg = ("Usage: structured_network_wordcount_windowed.py "
" []")
print(msg, file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
host = sys.argv[1]
port = int(sys.argv[2])
diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py
index 425df309011a0..c5c186c11f79a 100644
--- a/examples/src/main/python/streaming/direct_kafka_wordcount.py
+++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py
@@ -39,7 +39,7 @@
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: direct_kafka_wordcount.py ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount")
ssc = StreamingContext(sc, 2)
diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py
index 5d6e6dc36d6f9..c8ea92b61ca6e 100644
--- a/examples/src/main/python/streaming/flume_wordcount.py
+++ b/examples/src/main/python/streaming/flume_wordcount.py
@@ -39,7 +39,7 @@
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: flume_wordcount.py ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonStreamingFlumeWordCount")
ssc = StreamingContext(sc, 1)
diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py
index f815dd26823d1..f9a5c43a8eaa9 100644
--- a/examples/src/main/python/streaming/hdfs_wordcount.py
+++ b/examples/src/main/python/streaming/hdfs_wordcount.py
@@ -35,7 +35,7 @@
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: hdfs_wordcount.py ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonStreamingHDFSWordCount")
ssc = StreamingContext(sc, 1)
diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py
index 704f6602e2297..e9ee08b9fd228 100644
--- a/examples/src/main/python/streaming/kafka_wordcount.py
+++ b/examples/src/main/python/streaming/kafka_wordcount.py
@@ -39,7 +39,7 @@
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: kafka_wordcount.py ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonStreamingKafkaWordCount")
ssc = StreamingContext(sc, 1)
diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py
index 9010fafb425e6..f3099d2517cd5 100644
--- a/examples/src/main/python/streaming/network_wordcount.py
+++ b/examples/src/main/python/streaming/network_wordcount.py
@@ -35,7 +35,7 @@
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: network_wordcount.py ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonStreamingNetworkWordCount")
ssc = StreamingContext(sc, 1)
diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py
index d51a380a5d5f9..2b5434c0c845a 100644
--- a/examples/src/main/python/streaming/network_wordjoinsentiments.py
+++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py
@@ -47,7 +47,7 @@ def print_happiest_words(rdd):
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: network_wordjoinsentiments.py ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonStreamingNetworkWordJoinSentiments")
ssc = StreamingContext(sc, 5)
diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py
index 52b2639cdf55c..60167dc772544 100644
--- a/examples/src/main/python/streaming/recoverable_network_wordcount.py
+++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py
@@ -101,7 +101,7 @@ def filterFunc(wordCount):
if len(sys.argv) != 5:
print("Usage: recoverable_network_wordcount.py "
"", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
host, port, checkpoint, output = sys.argv[1:]
ssc = StreamingContext.getOrCreate(checkpoint,
lambda: createContext(host, int(port), output))
diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py
index 7f12281c0e3fe..ab3cfc067994d 100644
--- a/examples/src/main/python/streaming/sql_network_wordcount.py
+++ b/examples/src/main/python/streaming/sql_network_wordcount.py
@@ -48,7 +48,7 @@ def getSparkSessionInstance(sparkConf):
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: sql_network_wordcount.py ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
host, port = sys.argv[1:]
sc = SparkContext(appName="PythonSqlNetworkWordCount")
ssc = StreamingContext(sc, 1)
diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py
index d7bb61e729f18..d5d1eba6c5969 100644
--- a/examples/src/main/python/streaming/stateful_network_wordcount.py
+++ b/examples/src/main/python/streaming/stateful_network_wordcount.py
@@ -39,7 +39,7 @@
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: stateful_network_wordcount.py ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount")
ssc = StreamingContext(sc, 1)
ssc.checkpoint("checkpoint")
diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py
index 3d5e44d5b2df1..a05e24ff3ff95 100755
--- a/examples/src/main/python/wordcount.py
+++ b/examples/src/main/python/wordcount.py
@@ -26,7 +26,7 @@
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: wordcount ", file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
spark = SparkSession\
.builder\
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala
new file mode 100644
index 0000000000000..64076f2deb706
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples
+
+import java.io.File
+
+import org.apache.spark.SparkFiles
+import org.apache.spark.sql.SparkSession
+
+/** Usage: SparkRemoteFileTest [file] */
+object SparkRemoteFileTest {
+ def main(args: Array[String]) {
+ if (args.length < 1) {
+ System.err.println("Usage: SparkRemoteFileTest ")
+ System.exit(1)
+ }
+ val spark = SparkSession
+ .builder()
+ .appName("SparkRemoteFileTest")
+ .getOrCreate()
+ val sc = spark.sparkContext
+ val rdd = sc.parallelize(Seq(1)).map(_ => {
+ val localLocation = SparkFiles.get(args(0))
+ println(s"${args(0)} is stored at: $localLocation")
+ new File(localLocation).isFile
+ })
+ val truthCheck = rdd.collect().head
+ println(s"Mounting of ${args(0)} was $truthCheck")
+ spark.stop()
+ }
+}
+// scalastyle:on println
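The new test exercises SparkFiles resolution for files shipped with the application. A minimal sketch of the same mechanism driven from user code, with a hypothetical local path, might look like:

    import java.io.File
    import org.apache.spark.SparkFiles
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("SparkFilesSketch").getOrCreate()
    val sc = spark.sparkContext
    sc.addFile("/tmp/test.txt")                 // hypothetical path; --files has the same effect
    val local = SparkFiles.get("test.txt")      // where the file is materialized on each node
    println(new File(local).isFile)             // true if the file was shipped successfully
    spark.stop()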
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala
new file mode 100644
index 0000000000000..ca8f7affb14e8
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.clustering.PowerIterationClustering
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+object PowerIterationClusteringExample {
+ def main(args: Array[String]): Unit = {
+ val spark = SparkSession
+ .builder
+ .appName(s"${this.getClass.getSimpleName}")
+ .getOrCreate()
+
+ // $example on$
+ val dataset = spark.createDataFrame(Seq(
+ (0L, 1L, 1.0),
+ (0L, 2L, 1.0),
+ (1L, 2L, 1.0),
+ (3L, 4L, 1.0),
+ (4L, 0L, 0.1)
+ )).toDF("src", "dst", "weight")
+
+ val model = new PowerIterationClustering().
+ setK(2).
+ setMaxIter(20).
+ setInitMode("degree").
+ setWeightCol("weight")
+
+ val prediction = model.assignClusters(dataset).select("id", "cluster")
+
+ // Shows the cluster assignment
+ prediction.show(false)
+ // $example off$
+
+ spark.stop()
+ }
+ }
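The assignments returned by assignClusters carry only the node id and its cluster; a brief, illustrative sketch of joining them back to the edge data used above (column names come from the example itself, the join is an assumption):

    import org.apache.spark.sql.functions.col

    val assignments = model.assignClusters(dataset)          // columns: id, cluster
    val srcWithCluster = dataset
      .join(assignments, col("src") === col("id"))
      .select("src", "dst", "weight", "cluster")
    srcWithCluster.show()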
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
index def06026bde96..2082fb71afdf1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
@@ -18,6 +18,9 @@
// scalastyle:off println
package org.apache.spark.examples.streaming
+import org.apache.kafka.clients.consumer.ConsumerConfig
+import org.apache.kafka.common.serialization.StringDeserializer
+
import org.apache.spark.SparkConf
import org.apache.spark.streaming._
import org.apache.spark.streaming.kafka010._
@@ -26,18 +29,20 @@ import org.apache.spark.streaming.kafka010._
* Consumes messages from one or more topics in Kafka and does wordcount.
* Usage: DirectKafkaWordCount
* is a list of one or more Kafka brokers
+ * is a consumer group name to consume from topics
* is a list of one or more kafka topics to consume from
*
* Example:
* $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port \
- * topic1,topic2
+ * consumer-group topic1,topic2
*/
object DirectKafkaWordCount {
def main(args: Array[String]) {
- if (args.length < 2) {
+ if (args.length < 3) {
System.err.println(s"""
|Usage: DirectKafkaWordCount
| is a list of one or more Kafka brokers
+ | is a consumer group name to consume from topics
| is a list of one or more kafka topics to consume from
|
""".stripMargin)
@@ -46,7 +51,7 @@ object DirectKafkaWordCount {
StreamingExamples.setStreamingLogLevels()
- val Array(brokers, topics) = args
+ val Array(brokers, groupId, topics) = args
// Create context with 2 second batch interval
val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount")
@@ -54,7 +59,11 @@ object DirectKafkaWordCount {
// Create direct kafka stream with brokers and topics
val topicsSet = topics.split(",").toSet
- val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers)
+ val kafkaParams = Map[String, Object](
+ ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG -> brokers,
+ ConsumerConfig.GROUP_ID_CONFIG -> groupId,
+ ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer],
+ ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer])
val messages = KafkaUtils.createDirectStream[String, String](
ssc,
LocationStrategies.PreferConsistent,
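The hunk ends mid-call; under the kafka-0-10 integration the rest of the stream construction would look roughly like the following sketch of the usual pattern:

    val messages = KafkaUtils.createDirectStream[String, String](
      ssc,
      LocationStrategies.PreferConsistent,
      ConsumerStrategies.Subscribe[String, String](topicsSet, kafkaParams))
    val lines = messages.map(_.value)
    val wordCounts = lines.flatMap(_.split(" ")).map(x => (x, 1L)).reduceByKey(_ + _)
    wordCounts.print()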
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
index b049a054cb40e..badaa69cc303c 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -27,13 +27,10 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.types.UTF8String
/**
* A [[ContinuousReader]] for data from kafka.
@@ -66,7 +63,7 @@ class KafkaContinuousReader(
// Initialized when creating reader factories. If this diverges from the partitions at the latest
// offsets, we need to reconfigure.
// Exposed outside this object only for unit tests.
- private[sql] var knownPartitions: Set[TopicPartition] = _
+ @volatile private[sql] var knownPartitions: Set[TopicPartition] = _
override def readSchema: StructType = KafkaOffsetReader.kafkaSchema
@@ -89,7 +86,7 @@ class KafkaContinuousReader(
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}
- override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+ override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
import scala.collection.JavaConverters._
val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
@@ -109,9 +106,9 @@ class KafkaContinuousReader(
startOffsets.toSeq.map {
case (topicPartition, start) =>
- KafkaContinuousDataReaderFactory(
+ KafkaContinuousInputPartition(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
- .asInstanceOf[DataReaderFactory[UnsafeRow]]
+ .asInstanceOf[InputPartition[UnsafeRow]]
}.asJava
}
@@ -149,7 +146,7 @@ class KafkaContinuousReader(
}
/**
- * A data reader factory for continuous Kafka processing. This will be serialized and transformed
+ * An input partition for continuous Kafka processing. This will be serialized and transformed
* into a full reader on executors.
*
* @param topicPartition The (topic, partition) pair this task is responsible for.
@@ -159,14 +156,23 @@ class KafkaContinuousReader(
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
*/
-case class KafkaContinuousDataReaderFactory(
+case class KafkaContinuousInputPartition(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
- failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] {
- override def createDataReader(): KafkaContinuousDataReader = {
- new KafkaContinuousDataReader(
+ failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] {
+
+ override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = {
+ val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
+ require(kafkaOffset.topicPartition == topicPartition,
+ s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}")
+ new KafkaContinuousInputPartitionReader(
+ topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
+ }
+
+ override def createPartitionReader(): KafkaContinuousInputPartitionReader = {
+ new KafkaContinuousInputPartitionReader(
topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
}
}
@@ -181,19 +187,14 @@ case class KafkaContinuousDataReaderFactory(
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
*/
-class KafkaContinuousDataReader(
+class KafkaContinuousInputPartitionReader(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
- failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] {
- private val topic = topicPartition.topic
- private val kafkaPartition = topicPartition.partition
- private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams)
-
- private val sharedRow = new UnsafeRow(7)
- private val bufferHolder = new BufferHolder(sharedRow)
- private val rowWriter = new UnsafeRowWriter(bufferHolder, 7)
+ failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] {
+ private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false)
+ private val converter = new KafkaRecordToUnsafeRowConverter
private var nextKafkaOffset = startOffset
private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _
@@ -232,22 +233,7 @@ class KafkaContinuousDataReader(
}
override def get(): UnsafeRow = {
- bufferHolder.reset()
-
- if (currentRecord.key == null) {
- rowWriter.setNullAt(0)
- } else {
- rowWriter.write(0, currentRecord.key)
- }
- rowWriter.write(1, currentRecord.value)
- rowWriter.write(2, UTF8String.fromString(currentRecord.topic))
- rowWriter.write(3, currentRecord.partition)
- rowWriter.write(4, currentRecord.offset)
- rowWriter.write(5,
- DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp)))
- rowWriter.write(6, currentRecord.timestampType.id)
- sharedRow.setTotalSize(bufferHolder.totalSize)
- sharedRow
+ converter.toUnsafeRow(currentRecord)
}
override def getOffset(): KafkaSourcePartitionOffset = {
@@ -255,6 +241,6 @@ class KafkaContinuousDataReader(
}
override def close(): Unit = {
- consumer.close()
+ consumer.release()
}
}
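KafkaContinuousInputPartition now has two entry points: createPartitionReader starts from the partition's configured offset, while createContinuousReader resumes from a checkpointed PartitionOffset. A rough usage sketch, assuming the Kafka parameters are fully configured elsewhere:

    import java.util.{HashMap => JHashMap}
    import org.apache.kafka.common.TopicPartition

    val tp = new TopicPartition("t", 0)
    val kafkaParams = new JHashMap[String, Object]()   // assumed to hold a full consumer config
    val part = KafkaContinuousInputPartition(tp, startOffset = 0L, kafkaParams,
      pollTimeoutMs = 1000L, failOnDataLoss = false)

    val fromStart = part.createPartitionReader()                                   // fresh start
    val resumed = part.createContinuousReader(KafkaSourcePartitionOffset(tp, 42L)) // resume at 42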
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
similarity index 65%
rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
index 90ed7b1fba2f8..941f0ab177e48 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
@@ -27,30 +27,73 @@ import org.apache.kafka.common.TopicPartition
import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.kafka010.KafkaSource._
+import org.apache.spark.sql.kafka010.KafkaDataConsumer.AvailableOffsetRange
+import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.util.UninterruptibleThread
+private[kafka010] sealed trait KafkaDataConsumer {
+ /**
+ * Get the record for the given offset if available. Otherwise it will either throw an error
+ * (if failOnDataLoss = true), or return the next available record within [offset, untilOffset),
+ * or null.
+ *
+ * @param offset the offset to fetch.
+ * @param untilOffset the max offset to fetch. Exclusive.
+ * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka.
+ * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at
+ * offset if available, or throw an exception. When `failOnDataLoss` is `false`,
+ * this method will either return record at offset if available, or return
+ * the next earliest available record less than untilOffset, or null. It
+ * will not throw any exception.
+ */
+ def get(
+ offset: Long,
+ untilOffset: Long,
+ pollTimeoutMs: Long,
+ failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = {
+ internalConsumer.get(offset, untilOffset, pollTimeoutMs, failOnDataLoss)
+ }
+
+ /**
+ * Return the available offset range of the current partition. It's a pair of the earliest offset
+ * and the latest offset.
+ */
+ def getAvailableOffsetRange(): AvailableOffsetRange = internalConsumer.getAvailableOffsetRange()
+
+ /**
+ * Release this consumer from being further used. Depending on its implementation,
+ * this consumer will be either finalized, or reset for reuse later.
+ */
+ def release(): Unit
+
+ /** Reference to the internal implementation that this wrapper delegates to */
+ protected def internalConsumer: InternalKafkaConsumer
+}
+
/**
- * Consumer of single topicpartition, intended for cached reuse.
- * Underlying consumer is not threadsafe, so neither is this,
- * but processing the same topicpartition and group id in multiple threads is usually bad anyway.
+ * A wrapper around Kafka's KafkaConsumer that throws an error when data loss is detected.
+ * This is not for direct use outside this file.
*/
-private[kafka010] case class CachedKafkaConsumer private(
+private[kafka010] case class InternalKafkaConsumer(
topicPartition: TopicPartition,
kafkaParams: ju.Map[String, Object]) extends Logging {
- import CachedKafkaConsumer._
+ import InternalKafkaConsumer._
private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
- private var consumer = createConsumer
+ @volatile private var consumer = createConsumer
/** indicates whether this consumer is in use or not */
- private var inuse = true
+ @volatile var inUse = true
+
+ /** indicate whether this consumer is going to be stopped in the next release */
+ @volatile var markedForClose = false
/** Iterator to the already fetch data */
- private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]]
- private var nextOffsetInFetchedData = UNKNOWN_OFFSET
+ @volatile private var fetchedData =
+ ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]]
+ @volatile private var nextOffsetInFetchedData = UNKNOWN_OFFSET
/** Create a KafkaConsumer to fetch records for `topicPartition` */
private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = {
@@ -61,8 +104,6 @@ private[kafka010] case class CachedKafkaConsumer private(
c
}
- case class AvailableOffsetRange(earliest: Long, latest: Long)
-
private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match {
case ut: UninterruptibleThread =>
ut.runUninterruptibly(body)
@@ -313,21 +354,51 @@ private[kafka010] case class CachedKafkaConsumer private(
}
}
-private[kafka010] object CachedKafkaConsumer extends Logging {
- private val UNKNOWN_OFFSET = -2L
+private[kafka010] object KafkaDataConsumer extends Logging {
+
+ case class AvailableOffsetRange(earliest: Long, latest: Long)
+
+ private case class CachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer)
+ extends KafkaDataConsumer {
+ assert(internalConsumer.inUse) // make sure this has been set to true
+ override def release(): Unit = { KafkaDataConsumer.release(internalConsumer) }
+ }
+
+ private case class NonCachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer)
+ extends KafkaDataConsumer {
+ override def release(): Unit = { internalConsumer.close() }
+ }
- private case class CacheKey(groupId: String, topicPartition: TopicPartition)
+ private case class CacheKey(groupId: String, topicPartition: TopicPartition) {
+ def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) =
+ this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition)
+ }
+ // This cache has the following important properties.
+ // - We make a best-effort attempt to maintain the max size of the cache as configured capacity.
+ // The capacity is not guaranteed to be maintained, especially when there are more active
+ // tasks simultaneously using consumers than the capacity.
private lazy val cache = {
val conf = SparkEnv.get.conf
val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64)
- new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) {
+ new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) {
override def removeEldestEntry(
- entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = {
- if (entry.getValue.inuse == false && this.size > capacity) {
- logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " +
- s"removing consumer for ${entry.getKey}")
+ entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer]): Boolean = {
+
+ // Try to remove the least-used entry if it is currently not in use.
+ //
+ // If it cannot be removed, then the cache will keep growing. In the worst case,
+ // the cache will grow to the max number of concurrent tasks that can run in the executor,
+ // (that is, the number of task slots), after which it will never shrink. This is unlikely to
+ // be a serious problem because an executor with more than 64 (default) task slots is
+ // likely running on a beefy machine that can handle a large number of simultaneously
+ // active consumers.
+
+ if (!entry.getValue.inUse && this.size > capacity) {
+ logWarning(
+ s"KafkaConsumer cache hitting max capacity of $capacity, " +
+ s"removing consumer for ${entry.getKey}")
try {
entry.getValue.close()
} catch {
@@ -342,80 +413,87 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
}
}
- def releaseKafkaConsumer(
- topic: String,
- partition: Int,
- kafkaParams: ju.Map[String, Object]): Unit = {
- val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
- val topicPartition = new TopicPartition(topic, partition)
- val key = CacheKey(groupId, topicPartition)
-
- synchronized {
- val consumer = cache.get(key)
- if (consumer != null) {
- consumer.inuse = false
- } else {
- logWarning(s"Attempting to release consumer that does not exist")
- }
- }
- }
-
/**
- * Removes (and closes) the Kafka Consumer for the given topic, partition and group id.
+ * Get a cached consumer for groupId, assigned to topic and partition.
+ * If a matching consumer doesn't already exist, one will be created using kafkaParams.
+ * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]].
+ *
+ * Note: This method guarantees that the consumer returned is not currently in use by anyone
+ * else. Within this guarantee, this method will make a best-effort attempt to reuse consumers by
+ * caching them and tracking when they are in use.
*/
- def removeKafkaConsumer(
- topic: String,
- partition: Int,
- kafkaParams: ju.Map[String, Object]): Unit = {
- val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
- val topicPartition = new TopicPartition(topic, partition)
- val key = CacheKey(groupId, topicPartition)
+ def acquire(
+ topicPartition: TopicPartition,
+ kafkaParams: ju.Map[String, Object],
+ useCache: Boolean): KafkaDataConsumer = synchronized {
+ val key = new CacheKey(topicPartition, kafkaParams)
+ val existingInternalConsumer = cache.get(key)
- synchronized {
- val removedConsumer = cache.remove(key)
- if (removedConsumer != null) {
- removedConsumer.close()
+ lazy val newInternalConsumer = new InternalKafkaConsumer(topicPartition, kafkaParams)
+
+ if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) {
+ // If this is reattempt at running the task, then invalidate cached consumer if any and
+ // start with a new one.
+ if (existingInternalConsumer != null) {
+ // Consumer exists in cache. If it's in use, mark it for closing later; otherwise close it now.
+ if (existingInternalConsumer.inUse) {
+ existingInternalConsumer.markedForClose = true
+ } else {
+ existingInternalConsumer.close()
+ }
}
+ cache.remove(key) // Invalidate the cache in any case
+ NonCachedKafkaDataConsumer(newInternalConsumer)
+
+ } else if (!useCache) {
+ // If the planner asks not to reuse consumers, then skip the cache and return a new consumer
+ NonCachedKafkaDataConsumer(newInternalConsumer)
+
+ } else if (existingInternalConsumer == null) {
+ // If the consumer is not already cached, then put a new one in the cache and return it
+ cache.put(key, newInternalConsumer)
+ newInternalConsumer.inUse = true
+ CachedKafkaDataConsumer(newInternalConsumer)
+
+ } else if (existingInternalConsumer.inUse) {
+ // If consumer is already cached but is currently in use, then return a new consumer
+ NonCachedKafkaDataConsumer(newInternalConsumer)
+
+ } else {
+ // If consumer is already cached and is currently not in use, then return that consumer
+ existingInternalConsumer.inUse = true
+ CachedKafkaDataConsumer(existingInternalConsumer)
}
}
- /**
- * Get a cached consumer for groupId, assigned to topic and partition.
- * If matching consumer doesn't already exist, will be created using kafkaParams.
- */
- def getOrCreate(
- topic: String,
- partition: Int,
- kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized {
- val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
- val topicPartition = new TopicPartition(topic, partition)
- val key = CacheKey(groupId, topicPartition)
-
- // If this is reattempt at running the task, then invalidate cache and start with
- // a new consumer
- if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) {
- removeKafkaConsumer(topic, partition, kafkaParams)
- val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams)
- consumer.inuse = true
- cache.put(key, consumer)
- consumer
- } else {
- if (!cache.containsKey(key)) {
- cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams))
+ private def release(intConsumer: InternalKafkaConsumer): Unit = {
+ synchronized {
+
+ // Clear the consumer from the cache if this is indeed the consumer present in the cache
+ val key = new CacheKey(intConsumer.topicPartition, intConsumer.kafkaParams)
+ val cachedIntConsumer = cache.get(key)
+ if (intConsumer.eq(cachedIntConsumer)) {
+ // The released consumer is the same object as the cached one.
+ if (intConsumer.markedForClose) {
+ intConsumer.close()
+ cache.remove(key)
+ } else {
+ intConsumer.inUse = false
+ }
+ } else {
+ // The released consumer is either not the same one as in the cache, or not in the cache
+ // at all. This may happen if the cache was invalidated while this consumer was being used.
+ // Just close this consumer.
+ intConsumer.close()
+ logInfo(s"Released a supposedly cached consumer that was not found in the cache")
}
- val consumer = cache.get(key)
- consumer.inuse = true
- consumer
}
}
+}
- /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */
- def createUncached(
- topic: String,
- partition: Int,
- kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = {
- new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams)
- }
+private[kafka010] object InternalKafkaConsumer extends Logging {
+
+ private val UNKNOWN_OFFSET = -2L
private def reportDataLoss0(
failOnDataLoss: Boolean,
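The new acquire/release API replaces getOrCreate/releaseKafkaConsumer. A minimal sketch of the intended lifecycle from a caller's point of view, assuming the consumer config is built elsewhere:

    import java.util.{HashMap => JHashMap}
    import org.apache.kafka.common.TopicPartition

    val tp = new TopicPartition("t", 0)
    val params = new JHashMap[String, Object]()        // assumed to hold a full consumer config
    val consumer = KafkaDataConsumer.acquire(tp, params, useCache = true)
    try {
      // Either the record at offset 5, the next available record before offset 10, or null.
      val record = consumer.get(offset = 5L, untilOffset = 10L,
        pollTimeoutMs = 1000L, failOnDataLoss = false)
    } finally {
      consumer.release()   // hands a cached consumer back to the pool, closes a non-cached one
    }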
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
new file mode 100644
index 0000000000000..737da2e51b125
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
@@ -0,0 +1,381 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+import java.io._
+import java.nio.charset.StandardCharsets
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.IOUtils
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.UninterruptibleThread
+
+/**
+ * A [[MicroBatchReader]] that reads data from Kafka.
+ *
+ * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains
+ * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For
+ * example if the last record in a Kafka topic "t", partition 2 is offset 5, then
+ * KafkaSourceOffset will contain TopicPartition("t", 2) -> 6. This is done keep it consistent
+ * with the semantics of `KafkaConsumer.position()`.
+ *
+ * Zero data loss is not guaranteed when topics are deleted. If zero data loss is critical, the
+ * user must make sure all messages in a topic have been processed before deleting the topic.
+ *
+ * There is a known issue caused by KAFKA-1894: a query that uses Kafka may not be able to stop.
+ * To avoid this issue, you should make sure to stop the query before stopping the Kafka brokers,
+ * and avoid using incorrect broker addresses.
+ */
+private[kafka010] class KafkaMicroBatchReader(
+ kafkaOffsetReader: KafkaOffsetReader,
+ executorKafkaParams: ju.Map[String, Object],
+ options: DataSourceOptions,
+ metadataPath: String,
+ startingOffsets: KafkaOffsetRangeLimit,
+ failOnDataLoss: Boolean)
+ extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
+
+ private var startPartitionOffsets: PartitionOffsetMap = _
+ private var endPartitionOffsets: PartitionOffsetMap = _
+
+ private val pollTimeoutMs = options.getLong(
+ "kafkaConsumer.pollTimeoutMs",
+ SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L)
+
+ private val maxOffsetsPerTrigger =
+ Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong)
+
+ private val rangeCalculator = KafkaOffsetRangeCalculator(options)
+ /**
+ * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
+ * called in StreamExecutionThread. Otherwise, interrupting a thread while running
+ * `KafkaConsumer.poll` may hang forever (KAFKA-1894).
+ */
+ private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets()
+
+ override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = {
+ // Make sure initialPartitionOffsets is initialized
+ initialPartitionOffsets
+
+ startPartitionOffsets = Option(start.orElse(null))
+ .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets)
+ .getOrElse(initialPartitionOffsets)
+
+ endPartitionOffsets = Option(end.orElse(null))
+ .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets)
+ .getOrElse {
+ val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets()
+ maxOffsetsPerTrigger.map { maxOffsets =>
+ rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets)
+ }.getOrElse {
+ latestPartitionOffsets
+ }
+ }
+ }
+
+ override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
+ // Find the new partitions, and get their earliest offsets
+ val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
+ val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+ if (newPartitionInitialOffsets.keySet != newPartitions) {
+ // We cannot get the from (starting) offsets for some partitions, which means they were deleted.
+ val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet)
+ reportDataLoss(
+ s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed")
+ }
+ logInfo(s"Partitions added: $newPartitionInitialOffsets")
+ newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) =>
+ reportDataLoss(
+ s"Added partition $p starts from $o instead of 0. Some data may have been missed")
+ }
+
+ // Find deleted partitions, and report data loss if required
+ val deletedPartitions = startPartitionOffsets.keySet.diff(endPartitionOffsets.keySet)
+ if (deletedPartitions.nonEmpty) {
+ reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed")
+ }
+
+ // Use the end partitions to calculate offset ranges to ignore partitions that have
+ // been deleted
+ val topicPartitions = endPartitionOffsets.keySet.filter { tp =>
+ // Ignore partitions whose from (starting) offsets we don't know.
+ newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp)
+ }.toSeq
+ logDebug("TopicPartitions: " + topicPartitions.mkString(", "))
+
+ // Calculate offset ranges
+ val offsetRanges = rangeCalculator.getRanges(
+ fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets,
+ untilOffsets = endPartitionOffsets,
+ executorLocations = getSortedExecutorList())
+
+ // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions,
+ // that is, concurrent tasks will not read the same TopicPartitions.
+ val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size
+
+ // Generate factories based on the offset ranges
+ val factories = offsetRanges.map { range =>
+ new KafkaMicroBatchInputPartition(
+ range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
+ }
+ factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava
+ }
+
+ override def getStartOffset: Offset = {
+ KafkaSourceOffset(startPartitionOffsets)
+ }
+
+ override def getEndOffset: Offset = {
+ KafkaSourceOffset(endPartitionOffsets)
+ }
+
+ override def deserializeOffset(json: String): Offset = {
+ KafkaSourceOffset(JsonUtils.partitionOffsets(json))
+ }
+
+ override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema
+
+ override def commit(end: Offset): Unit = {}
+
+ override def stop(): Unit = {
+ kafkaOffsetReader.close()
+ }
+
+ override def toString(): String = s"KafkaV2[$kafkaOffsetReader]"
+
+ /**
+ * Read initial partition offsets from the checkpoint, or decide the offsets and write them to
+ * the checkpoint.
+ */
+ private def getOrCreateInitialPartitionOffsets(): PartitionOffsetMap = {
+ // Make sure that `KafkaConsumer.poll` is only called in StreamExecutionThread.
+ // Otherwise, interrupting a thread while running `KafkaConsumer.poll` may hang forever
+ // (KAFKA-1894).
+ assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
+
+ // SparkSession is required for getting Hadoop configuration for writing to checkpoints
+ assert(SparkSession.getActiveSession.nonEmpty)
+
+ val metadataLog =
+ new KafkaSourceInitialOffsetWriter(SparkSession.getActiveSession.get, metadataPath)
+ metadataLog.get(0).getOrElse {
+ val offsets = startingOffsets match {
+ case EarliestOffsetRangeLimit =>
+ KafkaSourceOffset(kafkaOffsetReader.fetchEarliestOffsets())
+ case LatestOffsetRangeLimit =>
+ KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets())
+ case SpecificOffsetRangeLimit(p) =>
+ kafkaOffsetReader.fetchSpecificOffsets(p, reportDataLoss)
+ }
+ metadataLog.add(0, offsets)
+ logInfo(s"Initial offsets: $offsets")
+ offsets
+ }.partitionToOffsets
+ }
+
+ /** Proportionally distribute the limit number of offsets among topic-partitions. */
+ private def rateLimit(
+ limit: Long,
+ from: PartitionOffsetMap,
+ until: PartitionOffsetMap): PartitionOffsetMap = {
+ val fromNew = kafkaOffsetReader.fetchEarliestOffsets(until.keySet.diff(from.keySet).toSeq)
+ val sizes = until.flatMap {
+ case (tp, end) =>
+ // If begin isn't defined, something's wrong, but let the alert logic in getBatch handle it
+ from.get(tp).orElse(fromNew.get(tp)).flatMap { begin =>
+ val size = end - begin
+ logDebug(s"rateLimit $tp size is $size")
+ if (size > 0) Some(tp -> size) else None
+ }
+ }
+ val total = sizes.values.sum.toDouble
+ if (total < 1) {
+ until
+ } else {
+ until.map {
+ case (tp, end) =>
+ tp -> sizes.get(tp).map { size =>
+ val begin = from.get(tp).getOrElse(fromNew(tp))
+ val prorate = limit * (size / total)
+ // Don't completely starve small topicpartitions
+ val off = begin + (if (prorate < 1) Math.ceil(prorate) else Math.floor(prorate)).toLong
+ // Paranoia, make sure not to return an offset that's past end
+ Math.min(end, off)
+ }.getOrElse(end)
+ }
+ }
+ }
+
+ private def getSortedExecutorList(): Array[String] = {
+
+ def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = {
+ if (a.host == b.host) {
+ a.executorId > b.executorId
+ } else {
+ a.host > b.host
+ }
+ }
+
+ val bm = SparkEnv.get.blockManager
+ bm.master.getPeers(bm.blockManagerId).toArray
+ .map(x => ExecutorCacheTaskLocation(x.host, x.executorId))
+ .sortWith(compare)
+ .map(_.toString)
+ }
+
+ /**
+ * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`.
+ * Otherwise, just log a warning.
+ */
+ private def reportDataLoss(message: String): Unit = {
+ if (failOnDataLoss) {
+ throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE")
+ } else {
+ logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE")
+ }
+ }
+
+ /** A version of [[HDFSMetadataLog]] specialized for saving the initial offsets. */
+ class KafkaSourceInitialOffsetWriter(sparkSession: SparkSession, metadataPath: String)
+ extends HDFSMetadataLog[KafkaSourceOffset](sparkSession, metadataPath) {
+
+ val VERSION = 1
+
+ override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = {
+ out.write(0) // A zero byte is written to support Spark 2.1.0 (SPARK-19517)
+ val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
+ writer.write("v" + VERSION + "\n")
+ writer.write(metadata.json)
+ writer.flush
+ }
+
+ override def deserialize(in: InputStream): KafkaSourceOffset = {
+ in.read() // A zero byte is read to support Spark 2.1.0 (SPARK-19517)
+ val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
+ // HDFSMetadataLog guarantees that it never creates a partial file.
+ assert(content.length != 0)
+ if (content(0) == 'v') {
+ val indexOfNewLine = content.indexOf("\n")
+ if (indexOfNewLine > 0) {
+ val version = parseVersion(content.substring(0, indexOfNewLine), VERSION)
+ KafkaSourceOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
+ } else {
+ throw new IllegalStateException(
+ s"Log file was malformed: failed to detect the log file version line.")
+ }
+ } else {
+ // The log was generated by Spark 2.1.0
+ KafkaSourceOffset(SerializedOffset(content))
+ }
+ }
+ }
+}
+
+/** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */
+private[kafka010] case class KafkaMicroBatchInputPartition(
+ offsetRange: KafkaOffsetRange,
+ executorKafkaParams: ju.Map[String, Object],
+ pollTimeoutMs: Long,
+ failOnDataLoss: Boolean,
+ reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] {
+
+ override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray
+
+ override def createPartitionReader(): InputPartitionReader[UnsafeRow] =
+ new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs,
+ failOnDataLoss, reuseKafkaConsumer)
+}
+
+/** A [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. */
+private[kafka010] case class KafkaMicroBatchInputPartitionReader(
+ offsetRange: KafkaOffsetRange,
+ executorKafkaParams: ju.Map[String, Object],
+ pollTimeoutMs: Long,
+ failOnDataLoss: Boolean,
+ reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging {
+
+ private val consumer = KafkaDataConsumer.acquire(
+ offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer)
+
+ private val rangeToRead = resolveRange(offsetRange)
+ private val converter = new KafkaRecordToUnsafeRowConverter
+
+ private var nextOffset = rangeToRead.fromOffset
+ private var nextRow: UnsafeRow = _
+
+ override def next(): Boolean = {
+ if (nextOffset < rangeToRead.untilOffset) {
+ val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss)
+ if (record != null) {
+ nextRow = converter.toUnsafeRow(record)
+ true
+ } else {
+ false
+ }
+ } else {
+ false
+ }
+ }
+
+ override def get(): UnsafeRow = {
+ assert(nextRow != null)
+ nextOffset += 1
+ nextRow
+ }
+
+ override def close(): Unit = {
+ consumer.release()
+ }
+
+ private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = {
+ if (range.fromOffset < 0 || range.untilOffset < 0) {
+ // Late bind the offset range
+ val availableOffsetRange = consumer.getAvailableOffsetRange()
+ val fromOffset = if (range.fromOffset < 0) {
+ assert(range.fromOffset == KafkaOffsetRangeLimit.EARLIEST,
+ s"earliest offset ${range.fromOffset} does not equal ${KafkaOffsetRangeLimit.EARLIEST}")
+ availableOffsetRange.earliest
+ } else {
+ range.fromOffset
+ }
+ val untilOffset = if (range.untilOffset < 0) {
+ assert(range.untilOffset == KafkaOffsetRangeLimit.LATEST,
+ s"latest offset ${range.untilOffset} does not equal ${KafkaOffsetRangeLimit.LATEST}")
+ availableOffsetRange.latest
+ } else {
+ range.untilOffset
+ }
+ KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset, None)
+ } else {
+ range
+ }
+ }
+}
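The rateLimit helper above distributes maxOffsetsPerTrigger proportionally to each partition's backlog. A small self-contained sketch of that arithmetic with assumed numbers:

    // Two partitions with 80 and 20 pending offsets, and maxOffsetsPerTrigger = 10.
    val limit = 10L
    val sizes = Map("tp-0" -> 80L, "tp-1" -> 20L)
    val total = sizes.values.sum.toDouble
    val advanceBy = sizes.map { case (tp, size) =>
      val prorate = limit * (size / total)
      // Small partitions are never completely starved: they advance by at least one offset.
      tp -> (if (prorate < 1) math.ceil(prorate) else math.floor(prorate)).toLong
    }
    // advanceBy == Map("tp-0" -> 8L, "tp-1" -> 2L)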
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala
new file mode 100644
index 0000000000000..6631ae84167c8
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+
+
+/**
+ * Class to calculate offset ranges to process based on the from and until offsets, and
+ * the configured `minPartitions`.
+ */
+private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) {
+ require(minPartitions.isEmpty || minPartitions.get > 0)
+
+ import KafkaOffsetRangeCalculator._
+ /**
+ * Calculate the offset ranges that we are going to process this batch. If `minPartitions`
+ * is not set or is set to a value less than or equal to the number of `topicPartitions` that
+ * we're going to consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions.
+ * If `minPartitions` is set higher than the number of our `topicPartitions`, then we will split
+ * up the read tasks of the skewed partitions into multiple Spark tasks.
+ * The number of Spark tasks will be *approximately* `minPartitions`. It can be fewer or more
+ * depending on rounding errors or Kafka partitions that didn't receive any new data.
+ */
+ def getRanges(
+ fromOffsets: PartitionOffsetMap,
+ untilOffsets: PartitionOffsetMap,
+ executorLocations: Seq[String] = Seq.empty): Seq[KafkaOffsetRange] = {
+ val partitionsToRead = untilOffsets.keySet.intersect(fromOffsets.keySet)
+
+ val offsetRanges = partitionsToRead.toSeq.map { tp =>
+ KafkaOffsetRange(tp, fromOffsets(tp), untilOffsets(tp), preferredLoc = None)
+ }.filter(_.size > 0)
+
+ // If minPartitions not set or there are enough partitions to satisfy minPartitions
+ if (minPartitions.isEmpty || offsetRanges.size > minPartitions.get) {
+ // Assign preferred executor locations to each range such that the same topic-partition is
+ // preferentially read from the same executor and the KafkaConsumer can be reused.
+ offsetRanges.map { range =>
+ range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations))
+ }
+ } else {
+
+ // Splits offset ranges with relatively large amount of data to smaller ones.
+ val totalSize = offsetRanges.map(_.size).sum
+ val idealRangeSize = totalSize.toDouble / minPartitions.get
+
+ offsetRanges.flatMap { range =>
+ // Split the current range into subranges as close to the ideal range size as possible
+ val numSplitsInRange = math.round(range.size.toDouble / idealRangeSize).toInt
+
+ (0 until numSplitsInRange).map { i =>
+ val splitStart = range.fromOffset + range.size * (i.toDouble / numSplitsInRange)
+ val splitEnd = range.fromOffset + range.size * ((i.toDouble + 1) / numSplitsInRange)
+ KafkaOffsetRange(
+ range.topicPartition, splitStart.toLong, splitEnd.toLong, preferredLoc = None)
+ }
+ }
+ }
+ }
+
+ private def getLocation(tp: TopicPartition, executorLocations: Seq[String]): Option[String] = {
+ def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b
+
+ val numExecutors = executorLocations.length
+ if (numExecutors > 0) {
+ // This allows cached KafkaConsumers in the executors to be re-used to read the same
+ // partition in every batch.
+ Some(executorLocations(floorMod(tp.hashCode, numExecutors)))
+ } else None
+ }
+}
+
+private[kafka010] object KafkaOffsetRangeCalculator {
+
+ def apply(options: DataSourceOptions): KafkaOffsetRangeCalculator = {
+ val optionalValue = Option(options.get("minPartitions").orElse(null)).map(_.toInt)
+ new KafkaOffsetRangeCalculator(optionalValue)
+ }
+}
+
+private[kafka010] case class KafkaOffsetRange(
+ topicPartition: TopicPartition,
+ fromOffset: Long,
+ untilOffset: Long,
+ preferredLoc: Option[String]) {
+ lazy val size: Long = untilOffset - fromOffset
+}
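A worked example of getRanges with assumed inputs: one skewed partition holding 100 pending offsets and minPartitions = 4.

    import org.apache.kafka.common.TopicPartition

    val calc = new KafkaOffsetRangeCalculator(minPartitions = Some(4))
    val tp = new TopicPartition("t", 0)
    val ranges = calc.getRanges(fromOffsets = Map(tp -> 0L), untilOffsets = Map(tp -> 100L))
    // Four subranges of ~25 offsets each: [0, 25), [25, 50), [50, 75), [75, 100).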
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
index 551641cfdbca8..82066697cb95a 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
@@ -75,7 +75,17 @@ private[kafka010] class KafkaOffsetReader(
* A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the
* offsets and never commits them.
*/
- protected var consumer = createConsumer()
+ @volatile protected var _consumer: Consumer[Array[Byte], Array[Byte]] = null
+
+ protected def consumer: Consumer[Array[Byte], Array[Byte]] = synchronized {
+ assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
+ if (_consumer == null) {
+ val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams)
+ newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId())
+ _consumer = consumerStrategy.createConsumer(newKafkaParams)
+ }
+ _consumer
+ }
private val maxOffsetFetchAttempts =
readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt
@@ -95,9 +105,7 @@ private[kafka010] class KafkaOffsetReader(
* Closes the connection to Kafka, and cleans up state.
*/
def close(): Unit = {
- runUninterruptibly {
- consumer.close()
- }
+ if (_consumer != null) runUninterruptibly { stopConsumer() }
kafkaReaderThread.shutdown()
}
@@ -304,19 +312,14 @@ private[kafka010] class KafkaOffsetReader(
}
}
- /**
- * Create a consumer using the new generated group id. We always use a new consumer to avoid
- * just using a broken consumer to retry on Kafka errors, which likely will fail again.
- */
- private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized {
- val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams)
- newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId())
- consumerStrategy.createConsumer(newKafkaParams)
+ private def stopConsumer(): Unit = synchronized {
+ assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
+ if (_consumer != null) _consumer.close()
}
private def resetConsumer(): Unit = synchronized {
- consumer.close()
- consumer = createConsumer()
+ stopConsumer()
+ _consumer = null // will automatically get reinitialized again
}
}
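The driver-side consumer is now created lazily behind a synchronized getter and torn down by resetConsumer, so a broken consumer is never reused on retries. A generic sketch of that pattern (illustrative, not Spark code):

    class LazilyRecreated[T >: Null <: AnyRef](create: () => T) {
      @volatile private var underlying: T = null
      def get: T = synchronized {
        if (underlying == null) underlying = create()
        underlying
      }
      def reset(close: T => Unit): Unit = synchronized {
        if (underlying != null) close(underlying)
        underlying = null   // lazily re-created on the next get
      }
    }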
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
new file mode 100644
index 0000000000000..f35a143e00374
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.apache.kafka.clients.consumer.ConsumerRecord
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.unsafe.types.UTF8String
+
+/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */
+private[kafka010] class KafkaRecordToUnsafeRowConverter {
+ private val rowWriter = new UnsafeRowWriter(7)
+
+ def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = {
+ rowWriter.reset()
+
+ if (record.key == null) {
+ rowWriter.setNullAt(0)
+ } else {
+ rowWriter.write(0, record.key)
+ }
+ rowWriter.write(1, record.value)
+ rowWriter.write(2, UTF8String.fromString(record.topic))
+ rowWriter.write(3, record.partition)
+ rowWriter.write(4, record.offset)
+ rowWriter.write(
+ 5,
+ DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp)))
+ rowWriter.write(6, record.timestampType.id)
+ rowWriter.getRow()
+ }
+}
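For reference, the seven fields written by the converter line up with the fixed schema the Kafka source exposes (KafkaOffsetReader.kafkaSchema); a sketch of that mapping:

    // position -> column, as produced by toUnsafeRow
    // 0 -> key (binary, nullable)    1 -> value (binary)
    // 2 -> topic (string)            3 -> partition (int)
    // 4 -> offset (long)             5 -> timestamp (timestamp)
    // 6 -> timestampType (int)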
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
index 7103709969c18..c31e6ed3e0903 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
@@ -48,7 +48,9 @@ private[kafka010] class KafkaRelation(
private val pollTimeoutMs = sourceOptions.getOrElse(
"kafkaConsumer.pollTimeoutMs",
- sqlContext.sparkContext.conf.getTimeAsMs("spark.network.timeout", "120s").toString
+ (sqlContext.sparkContext.conf.getTimeAsSeconds(
+ "spark.network.timeout",
+ "120s") * 1000L).toString
).toLong
override def schema: StructType = KafkaOffsetReader.kafkaSchema
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 169a5d006fb04..101e649727fcf 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.kafka010.KafkaSource._
+import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -83,7 +84,7 @@ private[kafka010] class KafkaSource(
private val pollTimeoutMs = sourceOptions.getOrElse(
"kafkaConsumer.pollTimeoutMs",
- sc.conf.getTimeAsMs("spark.network.timeout", "120s").toString
+ (sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L).toString
).toLong
private val maxOffsetsPerTrigger =
@@ -306,7 +307,7 @@ private[kafka010] class KafkaSource(
kafkaReader.close()
}
- override def toString(): String = s"KafkaSource[$kafkaReader]"
+ override def toString(): String = s"KafkaSourceV1[$kafkaReader]"
/**
* If `failOnDataLoss` is true, this method will throw an `IllegalStateException`.
@@ -323,22 +324,6 @@ private[kafka010] class KafkaSource(
/** Companion object for the [[KafkaSource]]. */
private[kafka010] object KafkaSource {
- val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE =
- """
- |Some data may have been lost because they are not available in Kafka any more; either the
- | data was aged out by Kafka or the topic may have been deleted before all the data in the
- | topic was processed. If you want your streaming query to fail on such cases, set the source
- | option "failOnDataLoss" to "true".
- """.stripMargin
-
- val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE =
- """
- |Some data may have been lost because they are not available in Kafka any more; either the
- | data was aged out by Kafka or the topic may have been deleted before all the data in the
- | topic was processed. If you don't want your streaming query to fail on such cases, set the
- | source option "failOnDataLoss" to "false".
- """.stripMargin
-
private[kafka010] val VERSION = 1
def getSortedExecutorList(sc: SparkContext): Array[String] = {
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 694ca76e24964..d225c1ea6b7f1 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -30,15 +30,14 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport
-import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
/**
- * The provider class for the [[KafkaSource]]. This provider is designed such that it throws
+ * The provider class for all Kafka readers and writers. It is designed such that it throws
* IllegalArgumentException when the Kafka Dataset is created, so that it can catch
* missing options even before the query is started.
*/
@@ -49,6 +48,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
with CreatableRelationProvider
with StreamWriteSupport
with ContinuousReadSupport
+ with MicroBatchReadSupport
with Logging {
import KafkaSourceProvider._
@@ -107,6 +107,52 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
failOnDataLoss(caseInsensitiveParams))
}
+ /**
+ * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches
+ * of Kafka data in a micro-batch streaming query.
+ */
+ override def createMicroBatchReader(
+ schema: Optional[StructType],
+ metadataPath: String,
+ options: DataSourceOptions): KafkaMicroBatchReader = {
+
+ val parameters = options.asMap().asScala.toMap
+ validateStreamOptions(parameters)
+ // Each running query should use its own group id. Otherwise, the query may be only assigned
+ // partial data since Kafka will assign partitions to multiple consumers having the same group
+ // id. Hence, we should generate a unique id for each query.
+ val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
+
+ val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+ val specifiedKafkaParams =
+ parameters
+ .keySet
+ .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
+ .map { k => k.drop(6).toString -> parameters(k) }
+ .toMap
+
+ val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
+ STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+
+ val kafkaOffsetReader = new KafkaOffsetReader(
+ strategy(caseInsensitiveParams),
+ kafkaParamsForDriver(specifiedKafkaParams),
+ parameters,
+ driverGroupIdPrefix = s"$uniqueGroupId-driver")
+
+ new KafkaMicroBatchReader(
+ kafkaOffsetReader,
+ kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
+ options,
+ metadataPath,
+ startingStreamOffsets,
+ failOnDataLoss(caseInsensitiveParams))
+ }
+
+ /**
+ * Creates a [[KafkaContinuousReader]] to read
+ * Kafka data in a continuous streaming query.
+ */
override def createContinuousReader(
schema: Optional[StructType],
metadataPath: String,
@@ -303,6 +349,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
throw new IllegalArgumentException("Unknown option")
}
+ // Validate minPartitions value if present
+ if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) {
+ val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt
+ if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive")
+ }
+
// Validate user-specified Kafka options
if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) {
@@ -410,8 +462,28 @@ private[kafka010] object KafkaSourceProvider extends Logging {
private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets"
private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
+ private val MIN_PARTITIONS_OPTION_KEY = "minpartitions"
+
val TOPIC_OPTION_KEY = "topic"
+ val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE =
+ """
+ |Some data may have been lost because they are not available in Kafka any more; either the
+ | data was aged out by Kafka or the topic may have been deleted before all the data in the
+ | topic was processed. If you want your streaming query to fail on such cases, set the source
+ | option "failOnDataLoss" to "true".
+ """.stripMargin
+
+ val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE =
+ """
+ |Some data may have been lost because they are not available in Kafka any more; either the
+ | data was aged out by Kafka or the topic may have been deleted before all the data in the
+ | topic was processed. If you don't want your streaming query to fail on such cases, set the
+ | source option "failOnDataLoss" to "false".
+ """.stripMargin
+
private val deserClassName = classOf[ByteArrayDeserializer].getName
def getKafkaOffsetRangeLimit(
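A usage sketch, not part of the patch, for the new micro-batch path and the minPartitions option validated above; the broker address and topic name are placeholders. The option must parse as a positive integer or the validation above rejects it:

    val df = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host:9092")  // placeholder broker address
      .option("subscribe", "events")                   // hypothetical topic
      .option("minPartitions", "8")                    // must be a positive integer
      .load()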
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
index 66b3409c0cd04..498e344ea39f4 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
@@ -52,7 +52,7 @@ private[kafka010] case class KafkaSourceRDDPartition(
* An RDD that reads data from Kafka based on offset ranges across multiple partitions.
* Additionally, it allows preferred locations to be set for each topic + partition, so that
* the [[KafkaSource]] can ensure the same executor always reads the same topic + partition
- * and cached KafkaConsumers (see [[CachedKafkaConsumer]] can be used read data efficiently.
+ * and cached KafkaConsumers (see [[KafkaDataConsumer]]) can be used to read data efficiently.
*
* @param sc the [[SparkContext]]
* @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors
@@ -126,14 +126,9 @@ private[kafka010] class KafkaSourceRDD(
val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition]
val topic = sourcePartition.offsetRange.topic
val kafkaPartition = sourcePartition.offsetRange.partition
- val consumer =
- if (!reuseKafkaConsumer) {
- // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we
- // uses `assign`, we don't need to worry about the "group.id" conflicts.
- CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams)
- } else {
- CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams)
- }
+ val consumer = KafkaDataConsumer.acquire(
+ sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer)
+
val range = resolveRange(consumer, sourcePartition.offsetRange)
assert(
range.fromOffset <= range.untilOffset,
@@ -167,13 +162,7 @@ private[kafka010] class KafkaSourceRDD(
}
override protected def close(): Unit = {
- if (!reuseKafkaConsumer) {
- // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster!
- consumer.close()
- } else {
- // Indicate that we're no longer using this consumer
- CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams)
- }
+ consumer.release()
}
}
// Release consumer, either by removing it or indicating we're no longer using it
@@ -184,7 +173,7 @@ private[kafka010] class KafkaSourceRDD(
}
}
- private def resolveRange(consumer: CachedKafkaConsumer, range: KafkaSourceRDDOffsetRange) = {
+ private def resolveRange(consumer: KafkaDataConsumer, range: KafkaSourceRDDOffsetRange) = {
if (range.fromOffset < 0 || range.untilOffset < 0) {
// Late bind the offset range
val availableOffsetRange = consumer.getAvailableOffsetRange()
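A minimal sketch of the acquire/release discipline introduced above, mirroring how the API is exercised in the KafkaDataConsumerSuite added later in this patch; release() either returns a cached consumer to the pool or closes a non-cached one, so it belongs in a finally block:

    import java.util.{Map => JMap}
    import org.apache.kafka.common.TopicPartition

    def readRange(tp: TopicPartition, kafkaParams: JMap[String, Object]): Unit = {
      // true = allow reuse of a cached consumer; the wrapper decides whether
      // release() returns it to the pool or closes it.
      val consumer = KafkaDataConsumer.acquire(tp, kafkaParams, true)
      try {
        val range = consumer.getAvailableOffsetRange()
        // fetch records in [range.earliest, range.latest) with consumer.get(...) here
      } finally {
        consumer.release()
      }
    }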
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
index 9307bfc001c03..ae5b5c52d514e 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
@@ -65,7 +65,10 @@ case class KafkaStreamWriterFactory(
topic: Option[String], producerParams: Map[String, String], schema: StructType)
extends DataWriterFactory[InternalRow] {
- override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
+ override def createDataWriter(
+ partitionId: Int,
+ attemptNumber: Int,
+ epochId: Long): DataWriter[InternalRow] = {
new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
similarity index 70%
rename from core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
index b1217980faf1f..43acd6a8d9473 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
@@ -14,18 +14,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.sql
-package org.apache.spark.util
-
-/**
- * Used for shipping per-thread stacktraces from the executors to driver.
- */
-private[spark] case class ThreadStackTrace(
- threadId: Long,
- threadName: String,
- threadState: Thread.State,
- stackTrace: String,
- blockedByThreadId: Option[Long],
- blockedByLock: String,
- holdingLocks: Seq[String])
+import org.apache.kafka.common.TopicPartition
+package object kafka010 { // scalastyle:ignore
+ // ^^ scalastyle:ignore is for ignoring warnings about digits in package name
+ type PartitionOffsetMap = Map[TopicPartition, Long]
+}
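A one-line illustration of the new alias, usable inside the kafka010 package with TopicPartition imported; the topic name is hypothetical:

    val offsets: PartitionOffsetMap = Map(new TopicPartition("events", 0) -> 42L)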
diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin
new file mode 100644
index 0000000000000..d530773f57327
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin
@@ -0,0 +1,2 @@
+0v99999
+{"kafka-initial-offset-future-version":{"2":2,"1":1,"0":0}}
\ No newline at end of file
diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin
index ae928e724967d..8c78d9e390a0e 100644
--- a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin
+++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin
@@ -1 +1 @@
-2{"kafka-initial-offset-2-1-0":{"2":0,"1":0,"0":0}}
\ No newline at end of file
+2{"kafka-initial-offset-2-1-0":{"2":2,"1":1,"0":0}}
\ No newline at end of file
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
index fc890a0cfdac3..ddfc0c1a4be2d 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
@@ -79,7 +79,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
val reader = createKafkaReader(topic)
.selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
.selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
- .as[(Int, Int)]
+ .as[(Option[Int], Int)]
.map(_._2)
try {
@@ -119,7 +119,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
val reader = createKafkaReader(topic)
.selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
.selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
- .as[(Int, Int)]
+ .as[(Option[Int], Int)]
.map(_._2)
try {
@@ -167,7 +167,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
val reader = createKafkaReader(topic)
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.selectExpr("CAST(key AS INT)", "CAST(value AS INT)")
- .as[(Int, Int)]
+ .as[(Option[Int], Int)]
.map(_._2)
try {
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
index a7083fa4e3417..aab8ec42189fb 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
@@ -17,20 +17,9 @@
package org.apache.spark.sql.kafka010
-import java.util.Properties
-import java.util.concurrent.atomic.AtomicInteger
-
-import org.scalatest.time.SpanSugar._
-import scala.collection.mutable
-import scala.util.Random
-
-import org.apache.spark.SparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
-import org.apache.spark.sql.streaming.{StreamTest, Trigger}
-import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
+import org.apache.spark.sql.streaming.Trigger
// Run tests in KafkaSourceSuiteBase in continuous execution mode.
class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest
@@ -71,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
eventually(timeout(streamingTimeout)) {
assert(
query.lastExecution.logical.collectFirst {
- case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+ case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r
}.exists { r =>
// Ensure the new topic is present and the old topic is gone.
r.knownPartitions.exists(_.topic == topic2)
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
index 5a1a14f7a307a..fa1468a3943c8 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.Trigger
@@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest {
eventually(timeout(streamingTimeout)) {
assert(
query.lastExecution.logical.collectFirst {
- case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+ case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r
}.exists(_.knownPartitions.size == newCount),
s"query never reconfigured to $newCount partitions")
}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
new file mode 100644
index 0000000000000..0d0fb9c3ab5af
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+import scala.util.Random
+
+import org.apache.kafka.clients.consumer.ConsumerConfig
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.serialization.ByteArrayDeserializer
+import org.scalatest.PrivateMethodTester
+
+import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.ThreadUtils
+
+class KafkaDataConsumerSuite extends SharedSQLContext with PrivateMethodTester {
+
+ protected var testUtils: KafkaTestUtils = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ testUtils = new KafkaTestUtils(Map[String, Object]())
+ testUtils.setup()
+ }
+
+ override def afterAll(): Unit = {
+ if (testUtils != null) {
+ testUtils.teardown()
+ testUtils = null
+ }
+ super.afterAll()
+ }
+
+ test("SPARK-19886: Report error cause correctly in reportDataLoss") {
+ val cause = new Exception("D'oh!")
+ val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0)
+ val e = intercept[IllegalStateException] {
+ InternalKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause))
+ }
+ assert(e.getCause === cause)
+ }
+
+ test("SPARK-23623: concurrent use of KafkaDataConsumer") {
+ val topic = "topic" + Random.nextInt()
+ val data = (1 to 1000).map(_.toString)
+ testUtils.createTopic(topic, 1)
+ testUtils.sendMessages(topic, data.toArray)
+ val topicPartition = new TopicPartition(topic, 0)
+
+ import ConsumerConfig._
+ val kafkaParams = Map[String, Object](
+ GROUP_ID_CONFIG -> "groupId",
+ BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
+ KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+ VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+ AUTO_OFFSET_RESET_CONFIG -> "earliest",
+ ENABLE_AUTO_COMMIT_CONFIG -> "false"
+ )
+
+ val numThreads = 100
+ val numConsumerUsages = 500
+
+ @volatile var error: Throwable = null
+
+ def consume(i: Int): Unit = {
+ val useCache = Random.nextBoolean
+ val taskContext = if (Random.nextBoolean) {
+ new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
+ } else {
+ null
+ }
+ TaskContext.setTaskContext(taskContext)
+ val consumer = KafkaDataConsumer.acquire(
+ topicPartition, kafkaParams.asJava, useCache)
+ try {
+ val range = consumer.getAvailableOffsetRange()
+ val rcvd = range.earliest until range.latest map { offset =>
+ val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value()
+ new String(bytes)
+ }
+ assert(rcvd == data)
+ } catch {
+ case e: Throwable =>
+ error = e
+ throw e
+ } finally {
+ consumer.release()
+ }
+ }
+
+ val threadpool = Executors.newFixedThreadPool(numThreads)
+ try {
+ val futures = (1 to numConsumerUsages).map { i =>
+ threadpool.submit(new Runnable {
+ override def run(): Unit = { consume(i) }
+ })
+ }
+ futures.foreach(_.get(1, TimeUnit.MINUTES))
+ assert(error == null)
+ } finally {
+ threadpool.shutdown()
+ }
+ }
+}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
similarity index 80%
rename from external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
rename to external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index 02c87643568bd..c6412eac97dba 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql.kafka010
import java.io._
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.{Files, Paths}
-import java.util.{Locale, Properties}
+import java.util.{Locale, Optional, Properties}
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.io.Source
import scala.util.Random
import org.apache.kafka.clients.producer.RecordMetadata
@@ -33,16 +35,19 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkContext
-import org.apache.spark.sql.{Dataset, ForeachWriter}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.types.StructType
abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
@@ -112,14 +117,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
query.nonEmpty,
"Cannot add data when there is no query for finding the active kafka source")
- val sources = query.get.logicalPlan.collect {
- case StreamingExecutionRelation(source: KafkaSource, _) => source
- } ++ (query.get.lastExecution match {
- case null => Seq()
- case e => e.logical.collect {
- case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
- }
- })
+ val sources = {
+ query.get.logicalPlan.collect {
+ case StreamingExecutionRelation(source: KafkaSource, _) => source
+ case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source
+ } ++ (query.get.lastExecution match {
+ case null => Seq()
+ case e => e.logical.collect {
+ case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader
+ }
+ })
+ }.distinct
+
if (sources.isEmpty) {
throw new Exception(
"Could not find Kafka source in the StreamExecution logical plan to add data to")
@@ -155,7 +164,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
}
-class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase {
+abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
import testImplicits._
@@ -303,94 +312,105 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase {
)
}
- testWithUninterruptibleThread(
- "deserialization of initial offset with Spark 2.1.0") {
+ test("ensure that initial offset are written with an extra byte in the beginning (SPARK-19517)") {
withTempDir { metadataPath =>
- val topic = newTopic
- testUtils.createTopic(topic, partitions = 3)
+ val topic = "kafka-initial-offset-current"
+ testUtils.createTopic(topic, partitions = 1)
- val provider = new KafkaSourceProvider
- val parameters = Map(
- "kafka.bootstrap.servers" -> testUtils.brokerAddress,
- "subscribe" -> topic
- )
- val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None,
- "", parameters)
- source.getOffset.get // Write initial offset
-
- // Make sure Spark 2.1.0 will throw an exception when reading the new log
- intercept[java.lang.IllegalArgumentException] {
- // Simulate how Spark 2.1.0 reads the log
- Utils.tryWithResource(new FileInputStream(metadataPath.getAbsolutePath + "/0")) { in =>
- val length = in.read()
- val bytes = new Array[Byte](length)
- in.read(bytes)
- KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8)))
- }
+ val initialOffsetFile = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0").toFile
+
+ val df = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", s"earliest")
+ .load()
+
+ // Test that the written initial offset file has a 0 byte at the beginning, so that
+ // Spark 2.1.0 can read the offsets (see SPARK-19517)
+ testStream(df)(
+ StartStream(checkpointLocation = metadataPath.getAbsolutePath),
+ makeSureGetOffsetCalled)
+
+ val binarySource = Source.fromFile(initialOffsetFile)
+ try {
+ assert(binarySource.next().toInt == 0) // first byte is binary 0
+ } finally {
+ binarySource.close()
}
}
}
- testWithUninterruptibleThread("deserialization of initial offset written by Spark 2.1.0") {
+ test("deserialization of initial offset written by Spark 2.1.0 (SPARK-19517)") {
withTempDir { metadataPath =>
val topic = "kafka-initial-offset-2-1-0"
testUtils.createTopic(topic, partitions = 3)
+ testUtils.sendMessages(topic, Array("0", "1", "2"), Some(0))
+ testUtils.sendMessages(topic, Array("0", "10", "20"), Some(1))
+ testUtils.sendMessages(topic, Array("0", "100", "200"), Some(2))
- val provider = new KafkaSourceProvider
- val parameters = Map(
- "kafka.bootstrap.servers" -> testUtils.brokerAddress,
- "subscribe" -> topic
- )
-
+ // Copy the initial offset file into the right location inside the checkpoint root directory
+ // such that the Kafka source can read it for initial offsets.
val from = new File(
getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").toURI).toPath
- val to = Paths.get(s"${metadataPath.getAbsolutePath}/0")
+ val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0")
+ Files.createDirectories(to.getParent)
Files.copy(from, to)
- val source = provider.createSource(
- spark.sqlContext, metadataPath.toURI.toString, None, "", parameters)
- val deserializedOffset = source.getOffset.get
- val referenceOffset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L))
- assert(referenceOffset == deserializedOffset)
+ val df = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", s"earliest")
+ .load()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+
+ // Test that the query starts from the initial offsets recorded in the checkpoint (i.e. it skips
+ // some of the oldest records even though startingOffsets is earliest).
+ testStream(df)(
+ StartStream(checkpointLocation = metadataPath.getAbsolutePath),
+ AddKafkaData(Set(topic), 1000),
+ CheckAnswer(0, 1, 2, 10, 20, 200, 1000))
}
}
- testWithUninterruptibleThread("deserialization of initial offset written by future version") {
+ test("deserialization of initial offset written by future version") {
withTempDir { metadataPath =>
- val futureMetadataLog =
- new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession,
- metadataPath.getAbsolutePath) {
- override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = {
- out.write(0)
- val writer = new BufferedWriter(new OutputStreamWriter(out, UTF_8))
- writer.write(s"v99999\n${metadata.json}")
- writer.flush
- }
- }
-
- val topic = newTopic
+ val topic = "kafka-initial-offset-future-version"
testUtils.createTopic(topic, partitions = 3)
- val offset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L))
- futureMetadataLog.add(0, offset)
- val provider = new KafkaSourceProvider
- val parameters = Map(
- "kafka.bootstrap.servers" -> testUtils.brokerAddress,
- "subscribe" -> topic
- )
- val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None,
- "", parameters)
-
- val e = intercept[java.lang.IllegalStateException] {
- source.getOffset.get // Read initial offset
- }
+ // Copy the initial offset file into the right location inside the checkpoint root directory
+ // such that the Kafka source can read it for initial offsets.
+ val from = new File(
+ getClass.getResource("/kafka-source-initial-offset-future-version.bin").toURI).toPath
+ val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0")
+ Files.createDirectories(to.getParent)
+ Files.copy(from, to)
- Seq(
- s"maximum supported log version is v${KafkaSource.VERSION}, but encountered v99999",
- "produced by a newer version of Spark and cannot be read by this version"
- ).foreach { message =>
- assert(e.getMessage.contains(message))
- }
+ val df = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .load()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+
+ testStream(df)(
+ StartStream(checkpointLocation = metadataPath.getAbsolutePath),
+ ExpectFailure[IllegalStateException](e => {
+ Seq(
+ s"maximum supported log version is v1, but encountered v99999",
+ "produced by a newer version of Spark and cannot be read by this version"
+ ).foreach { message =>
+ assert(e.toString.contains(message))
+ }
+ }))
}
}
@@ -542,6 +562,143 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase {
CheckLastBatch(120 to 124: _*)
)
}
+
+ test("ensure stream-stream self-join generates only one offset in log and correct metrics") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 2)
+ require(testUtils.getLatestOffsets(Set(topic)).size === 2)
+
+ val kafka = spark
+ .readStream
+ .format("kafka")
+ .option("subscribe", topic)
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .load()
+
+ val values = kafka
+ .selectExpr("CAST(CAST(value AS STRING) AS INT) AS value",
+ "CAST(CAST(value AS STRING) AS INT) % 5 AS key")
+
+ val join = values.join(values, "key")
+
+ testStream(join)(
+ makeSureGetOffsetCalled,
+ AddKafkaData(Set(topic), 1, 2),
+ CheckAnswer((1, 1, 1), (2, 2, 2)),
+ AddKafkaData(Set(topic), 6, 3),
+ CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)),
+ AssertOnQuery { q =>
+ assert(q.availableOffsets.iterator.size == 1)
+ assert(q.recentProgress.map(_.numInputRows).sum == 4)
+ true
+ }
+ )
+ }
+}
+
+
+class KafkaMicroBatchV1SourceSuite extends KafkaMicroBatchSourceSuiteBase {
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.conf.set(
+ "spark.sql.streaming.disabledV2MicroBatchReaders",
+ classOf[KafkaSourceProvider].getCanonicalName)
+ }
+
+ test("V1 Source is used when disabled through SQLConf") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 5)
+
+ val kafka = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .option("subscribePattern", s"$topic.*")
+ .load()
+
+ testStream(kafka)(
+ makeSureGetOffsetCalled,
+ AssertOnQuery { query =>
+ query.logicalPlan.collect {
+ case StreamingExecutionRelation(_: KafkaSource, _) => true
+ }.nonEmpty
+ }
+ )
+ }
+}
+
+class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
+
+ test("V2 Source is used by default") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 5)
+
+ val kafka = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .option("subscribePattern", s"$topic.*")
+ .load()
+
+ testStream(kafka)(
+ makeSureGetOffsetCalled,
+ AssertOnQuery { query =>
+ query.logicalPlan.collect {
+ case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true
+ }.nonEmpty
+ }
+ )
+ }
+
+ testWithUninterruptibleThread("minPartitions is supported") {
+ import testImplicits._
+
+ val topic = newTopic()
+ val tp = new TopicPartition(topic, 0)
+ testUtils.createTopic(topic, partitions = 1)
+
+ def test(
+ minPartitions: String,
+ numPartitionsGenerated: Int,
+ reusesConsumers: Boolean): Unit = {
+
+ SparkSession.setActiveSession(spark)
+ withTempDir { dir =>
+ val provider = new KafkaSourceProvider()
+ val options = Map(
+ "kafka.bootstrap.servers" -> testUtils.brokerAddress,
+ "subscribe" -> topic
+ ) ++ Option(minPartitions).map { p => "minPartitions" -> p}
+ val reader = provider.createMicroBatchReader(
+ Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava))
+ reader.setOffsetRange(
+ Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))),
+ Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
+ )
+ val factories = reader.planUnsafeInputPartitions().asScala
+ .map(_.asInstanceOf[KafkaMicroBatchInputPartition])
+ withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
+ assert(factories.size == numPartitionsGenerated)
+ factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) }
+ }
+ }
+ }
+
+ // Test cases when minPartitions is used and not used
+ test(minPartitions = null, numPartitionsGenerated = 1, reusesConsumers = true)
+ test(minPartitions = "1", numPartitionsGenerated = 1, reusesConsumers = true)
+ test(minPartitions = "4", numPartitionsGenerated = 4, reusesConsumers = false)
+
+ // Test illegal minPartitions values
+ intercept[IllegalArgumentException] { test(minPartitions = "a", 1, true) }
+ intercept[IllegalArgumentException] { test(minPartitions = "1.0", 1, true) }
+ intercept[IllegalArgumentException] { test(minPartitions = "0", 1, true) }
+ intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) }
+ }
+
}
abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala
new file mode 100644
index 0000000000000..2ccf3e291bea7
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import scala.collection.JavaConverters._
+
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+
+class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite {
+
+ def testWithMinPartitions(name: String, minPartition: Int)
+ (f: KafkaOffsetRangeCalculator => Unit): Unit = {
+ val options = new DataSourceOptions(Map("minPartitions" -> minPartition.toString).asJava)
+ test(s"with minPartition = $minPartition: $name") {
+ f(KafkaOffsetRangeCalculator(options))
+ }
+ }
+
+
+ test("with no minPartition: N TopicPartitions to N offset ranges") {
+ val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty())
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1),
+ untilOffsets = Map(tp1 -> 2)) ==
+ Seq(KafkaOffsetRange(tp1, 1, 2, None)))
+
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1),
+ untilOffsets = Map(tp1 -> 2, tp2 -> 1), Seq.empty) ==
+ Seq(KafkaOffsetRange(tp1, 1, 2, None)))
+
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1, tp2 -> 1),
+ untilOffsets = Map(tp1 -> 2)) ==
+ Seq(KafkaOffsetRange(tp1, 1, 2, None)))
+
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1, tp2 -> 1),
+ untilOffsets = Map(tp1 -> 2),
+ executorLocations = Seq("location")) ==
+ Seq(KafkaOffsetRange(tp1, 1, 2, Some("location"))))
+ }
+
+ test("with no minPartition: empty ranges ignored") {
+ val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty())
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1, tp2 -> 1),
+ untilOffsets = Map(tp1 -> 2, tp2 -> 1)) ==
+ Seq(KafkaOffsetRange(tp1, 1, 2, None)))
+ }
+
+ testWithMinPartitions("N TopicPartitions to N offset ranges", 3) { calc =>
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1),
+ untilOffsets = Map(tp1 -> 2, tp2 -> 2, tp3 -> 2)) ==
+ Seq(
+ KafkaOffsetRange(tp1, 1, 2, None),
+ KafkaOffsetRange(tp2, 1, 2, None),
+ KafkaOffsetRange(tp3, 1, 2, None)))
+ }
+
+ testWithMinPartitions("1 TopicPartition to N offset ranges", 4) { calc =>
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1),
+ untilOffsets = Map(tp1 -> 5)) ==
+ Seq(
+ KafkaOffsetRange(tp1, 1, 2, None),
+ KafkaOffsetRange(tp1, 2, 3, None),
+ KafkaOffsetRange(tp1, 3, 4, None),
+ KafkaOffsetRange(tp1, 4, 5, None)))
+
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1),
+ untilOffsets = Map(tp1 -> 5),
+ executorLocations = Seq("location")) ==
+ Seq(
+ KafkaOffsetRange(tp1, 1, 2, None),
+ KafkaOffsetRange(tp1, 2, 3, None),
+ KafkaOffsetRange(tp1, 3, 4, None),
+ KafkaOffsetRange(tp1, 4, 5, None))) // location pref not set when minPartition is set
+ }
+
+ testWithMinPartitions("N skewed TopicPartitions to M offset ranges", 3) { calc =>
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1, tp2 -> 1),
+ untilOffsets = Map(tp1 -> 5, tp2 -> 21)) ==
+ Seq(
+ KafkaOffsetRange(tp1, 1, 5, None),
+ KafkaOffsetRange(tp2, 1, 7, None),
+ KafkaOffsetRange(tp2, 7, 14, None),
+ KafkaOffsetRange(tp2, 14, 21, None)))
+ }
+
+ testWithMinPartitions("range inexact multiple of minPartitions", 3) { calc =>
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1),
+ untilOffsets = Map(tp1 -> 11)) ==
+ Seq(
+ KafkaOffsetRange(tp1, 1, 4, None),
+ KafkaOffsetRange(tp1, 4, 7, None),
+ KafkaOffsetRange(tp1, 7, 11, None)))
+ }
+
+ testWithMinPartitions("empty ranges ignored", 3) { calc =>
+ assert(
+ calc.getRanges(
+ fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1),
+ untilOffsets = Map(tp1 -> 5, tp2 -> 21, tp3 -> 1)) ==
+ Seq(
+ KafkaOffsetRange(tp1, 1, 5, None),
+ KafkaOffsetRange(tp2, 1, 7, None),
+ KafkaOffsetRange(tp2, 7, 14, None),
+ KafkaOffsetRange(tp2, 14, 21, None)))
+ }
+
+ private val tp1 = new TopicPartition("t1", 1)
+ private val tp2 = new TopicPartition("t2", 1)
+ private val tp3 = new TopicPartition("t3", 1)
+}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index 42f8b4c7657e2..7079ac6453ffc 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -138,7 +138,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext {
val reader = createKafkaReader(topic)
.selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
.selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
- .as[(Int, Int)]
+ .as[(Option[Int], Int)]
.map(_._2)
try {
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
deleted file mode 100644
index fa3ea6131a507..0000000000000
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
+++ /dev/null
@@ -1,189 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.kafka010
-
-import java.{ util => ju }
-
-import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer }
-import org.apache.kafka.common.{ KafkaException, TopicPartition }
-
-import org.apache.spark.SparkConf
-import org.apache.spark.internal.Logging
-
-
-/**
- * Consumer of single topicpartition, intended for cached reuse.
- * Underlying consumer is not threadsafe, so neither is this,
- * but processing the same topicpartition and group id in multiple threads is usually bad anyway.
- */
-private[kafka010]
-class CachedKafkaConsumer[K, V] private(
- val groupId: String,
- val topic: String,
- val partition: Int,
- val kafkaParams: ju.Map[String, Object]) extends Logging {
-
- assert(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG),
- "groupId used for cache key must match the groupId in kafkaParams")
-
- val topicPartition = new TopicPartition(topic, partition)
-
- protected val consumer = {
- val c = new KafkaConsumer[K, V](kafkaParams)
- val tps = new ju.ArrayList[TopicPartition]()
- tps.add(topicPartition)
- c.assign(tps)
- c
- }
-
- // TODO if the buffer was kept around as a random-access structure,
- // could possibly optimize re-calculating of an RDD in the same batch
- protected var buffer = ju.Collections.emptyList[ConsumerRecord[K, V]]().iterator
- protected var nextOffset = -2L
-
- def close(): Unit = consumer.close()
-
- /**
- * Get the record for the given offset, waiting up to timeout ms if IO is necessary.
- * Sequential forward access will use buffers, but random access will be horribly inefficient.
- */
- def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = {
- logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset")
- if (offset != nextOffset) {
- logInfo(s"Initial fetch for $groupId $topic $partition $offset")
- seek(offset)
- poll(timeout)
- }
-
- if (!buffer.hasNext()) { poll(timeout) }
- assert(buffer.hasNext(),
- s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout")
- var record = buffer.next()
-
- if (record.offset != offset) {
- logInfo(s"Buffer miss for $groupId $topic $partition $offset")
- seek(offset)
- poll(timeout)
- assert(buffer.hasNext(),
- s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout")
- record = buffer.next()
- assert(record.offset == offset,
- s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset")
- }
-
- nextOffset = offset + 1
- record
- }
-
- private def seek(offset: Long): Unit = {
- logDebug(s"Seeking to $topicPartition $offset")
- consumer.seek(topicPartition, offset)
- }
-
- private def poll(timeout: Long): Unit = {
- val p = consumer.poll(timeout)
- val r = p.records(topicPartition)
- logDebug(s"Polled ${p.partitions()} ${r.size}")
- buffer = r.iterator
- }
-
-}
-
-private[kafka010]
-object CachedKafkaConsumer extends Logging {
-
- private case class CacheKey(groupId: String, topic: String, partition: Int)
-
- // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap
- private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null
-
- /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */
- def init(
- initialCapacity: Int,
- maxCapacity: Int,
- loadFactor: Float): Unit = CachedKafkaConsumer.synchronized {
- if (null == cache) {
- logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor")
- cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]](
- initialCapacity, loadFactor, true) {
- override def removeEldestEntry(
- entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = {
- if (this.size > maxCapacity) {
- try {
- entry.getValue.consumer.close()
- } catch {
- case x: KafkaException =>
- logError("Error closing oldest Kafka consumer", x)
- }
- true
- } else {
- false
- }
- }
- }
- }
- }
-
- /**
- * Get a cached consumer for groupId, assigned to topic and partition.
- * If matching consumer doesn't already exist, will be created using kafkaParams.
- */
- def get[K, V](
- groupId: String,
- topic: String,
- partition: Int,
- kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
- CachedKafkaConsumer.synchronized {
- val k = CacheKey(groupId, topic, partition)
- val v = cache.get(k)
- if (null == v) {
- logInfo(s"Cache miss for $k")
- logDebug(cache.keySet.toString)
- val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
- cache.put(k, c)
- c
- } else {
- // any given topicpartition should have a consistent key and value type
- v.asInstanceOf[CachedKafkaConsumer[K, V]]
- }
- }
-
- /**
- * Get a fresh new instance, unassociated with the global cache.
- * Caller is responsible for closing
- */
- def getUncached[K, V](
- groupId: String,
- topic: String,
- partition: Int,
- kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
- new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
-
- /** remove consumer for given groupId, topic, and partition, if it exists */
- def remove(groupId: String, topic: String, partition: Int): Unit = {
- val k = CacheKey(groupId, topic, partition)
- logInfo(s"Removing $k from cache")
- val v = CachedKafkaConsumer.synchronized {
- cache.remove(k)
- }
- if (null != v) {
- v.close()
- logInfo(s"Removed $k from cache")
- }
- }
-}
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
index 0fa3287f36db8..c3221481556f5 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
@@ -56,6 +56,9 @@ private[spark] class DirectKafkaInputDStream[K, V](
ppc: PerPartitionConfig
) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets {
+ private val initialRate = context.sparkContext.getConf.getLong(
+ "spark.streaming.backpressure.initialRate", 0)
+
val executorKafkaParams = {
val ekp = new ju.HashMap[String, Object](consumerStrategy.executorKafkaParams)
KafkaUtils.fixKafkaParams(ekp)
@@ -126,7 +129,10 @@ private[spark] class DirectKafkaInputDStream[K, V](
protected[streaming] def maxMessagesPerPartition(
offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = {
- val estimatedRateLimit = rateController.map(_.getLatestRate())
+ val estimatedRateLimit = rateController.map { x => {
+ val lr = x.getLatestRate()
+ if (lr > 0) lr else initialRate
+ }}
// calculate a per-partition rate limit based on current lag
val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {
@@ -138,17 +144,17 @@ private[spark] class DirectKafkaInputDStream[K, V](
lagPerPartition.map { case (tp, lag) =>
val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp)
- val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
+ val backpressureRate = lag / totalLag.toDouble * rate
tp -> (if (maxRateLimitPerPartition > 0) {
Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
}
- case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) }
+ case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp).toDouble }
}
if (effectiveRateLimitPerPartition.values.sum > 0) {
val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
Some(effectiveRateLimitPerPartition.map {
- case (tp, limit) => tp -> (secsPerBatch * limit).toLong
+ case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L)
})
} else {
None
@@ -184,8 +190,20 @@ private[spark] class DirectKafkaInputDStream[K, V](
// make sure new partitions are reflected in currentOffsets
val newPartitions = parts.diff(currentOffsets.keySet)
+
+ // Check if there's any partition been revoked because of consumer rebalance.
+ val revokedPartitions = currentOffsets.keySet.diff(parts)
+ if (revokedPartitions.nonEmpty) {
+ throw new IllegalStateException(s"Previously tracked partitions " +
+ s"${revokedPartitions.mkString("[", ",", "]")} been revoked by Kafka because of consumer " +
+ s"rebalance. This is mostly due to another stream with same group id joined, " +
+ s"please check if there're different streaming application misconfigure to use same " +
+ s"group id. Fundamentally different stream should use different group id")
+ }
+
// position for new partitions determined by auto.offset.reset if no commit
currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap
+
// don't want to consume messages, so pause
c.pause(newPartitions.asJava)
// find latest available offsets
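A worked sketch of the per-partition backpressure math after these changes, using made-up numbers (a latest rate of 900 records/s, lags of 100 and 800, and 2 s batches): each partition gets lag / totalLag * rate, then secsPerBatch * limit records, floored at 1 so no partition is starved entirely:

    val rate = 900.0                                    // latest rate from the rate controller
    val lags = Map("tp0" -> 100L, "tp1" -> 800L)        // hypothetical per-partition lag
    val totalLag = lags.values.sum.toDouble
    val secsPerBatch = 2.0
    val maxMessages = lags.map { case (tp, lag) =>
      val backpressureRate = lag / totalLag * rate                   // 100.0 and 800.0 records/s
      tp -> math.max((secsPerBatch * backpressureRate).toLong, 1L)   // 200 and 1600 records
    }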
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
new file mode 100644
index 0000000000000..68c5fe9ab066a
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{util => ju}
+
+import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer}
+import org.apache.kafka.common.{KafkaException, TopicPartition}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+
+private[kafka010] sealed trait KafkaDataConsumer[K, V] {
+ /**
+ * Get the record for the given offset if available.
+ *
+ * @param offset the offset to fetch.
+ * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka.
+ */
+ def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+ internalConsumer.get(offset, pollTimeoutMs)
+ }
+
+ /**
+ * Start a batch on a compacted topic
+ *
+ * @param offset the offset to fetch.
+ * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka.
+ */
+ def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = {
+ internalConsumer.compactedStart(offset, pollTimeoutMs)
+ }
+
+ /**
+ * Get the next record in the batch from a compacted topic.
+ * Assumes compactedStart has been called first, and ignores gaps.
+ *
+ * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka.
+ */
+ def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+ internalConsumer.compactedNext(pollTimeoutMs)
+ }
+
+ /**
+ * Rewind to previous record in the batch from a compacted topic.
+ *
+ * @throws NoSuchElementException if no previous element
+ */
+ def compactedPrevious(): ConsumerRecord[K, V] = {
+ internalConsumer.compactedPrevious()
+ }
+
+ /**
+ * Release this consumer from being further used. Depending on its implementation,
+ * this consumer will be either finalized, or reset for reuse later.
+ */
+ def release(): Unit
+
+ /** Reference to the internal implementation that this wrapper delegates to */
+ def internalConsumer: InternalKafkaConsumer[K, V]
+}
+
+
+/**
+ * A wrapper around Kafka's KafkaConsumer.
+ * This is not for direct use outside this file.
+ */
+private[kafka010] class InternalKafkaConsumer[K, V](
+ val topicPartition: TopicPartition,
+ val kafkaParams: ju.Map[String, Object]) extends Logging {
+
+ private[kafka010] val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG)
+ .asInstanceOf[String]
+
+ private val consumer = createConsumer
+
+ /** indicates whether this consumer is in use or not */
+ var inUse = true
+
+ /** indicate whether this consumer is going to be stopped in the next release */
+ var markedForClose = false
+
+ // TODO if the buffer was kept around as a random-access structure,
+ // could possibly optimize re-calculating of an RDD in the same batch
+ @volatile private var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]()
+ @volatile private var nextOffset = InternalKafkaConsumer.UNKNOWN_OFFSET
+
+ override def toString: String = {
+ "InternalKafkaConsumer(" +
+ s"hash=${Integer.toHexString(hashCode)}, " +
+ s"groupId=$groupId, " +
+ s"topicPartition=$topicPartition)"
+ }
+
+ /** Create a KafkaConsumer to fetch records for `topicPartition` */
+ private def createConsumer: KafkaConsumer[K, V] = {
+ val c = new KafkaConsumer[K, V](kafkaParams)
+ val topics = ju.Arrays.asList(topicPartition)
+ c.assign(topics)
+ c
+ }
+
+ def close(): Unit = consumer.close()
+
+ /**
+ * Get the record for the given offset, waiting up to timeout ms if IO is necessary.
+ * Sequential forward access will use buffers, but random access will be horribly inefficient.
+ */
+ def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = {
+ logDebug(s"Get $groupId $topicPartition nextOffset $nextOffset requested $offset")
+ if (offset != nextOffset) {
+ logInfo(s"Initial fetch for $groupId $topicPartition $offset")
+ seek(offset)
+ poll(timeout)
+ }
+
+ if (!buffer.hasNext()) {
+ poll(timeout)
+ }
+ require(buffer.hasNext(),
+ s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout")
+ var record = buffer.next()
+
+ if (record.offset != offset) {
+ logInfo(s"Buffer miss for $groupId $topicPartition $offset")
+ seek(offset)
+ poll(timeout)
+ require(buffer.hasNext(),
+ s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout")
+ record = buffer.next()
+ require(record.offset == offset,
+ s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset " +
+ s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " +
+ "spark.streaming.kafka.allowNonConsecutiveOffsets"
+ )
+ }
+
+ nextOffset = offset + 1
+ record
+ }
+
+ /**
+ * Start a batch on a compacted topic
+ */
+ def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = {
+ logDebug(s"compacted start $groupId $topicPartition starting $offset")
+ // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics
+ if (offset != nextOffset) {
+ logInfo(s"Initial fetch for compacted $groupId $topicPartition $offset")
+ seek(offset)
+ poll(pollTimeoutMs)
+ }
+ }
+
+ /**
+ * Get the next record in the batch from a compacted topic.
+ * Assumes compactedStart has been called first, and ignores gaps.
+ */
+ def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+ if (!buffer.hasNext()) {
+ poll(pollTimeoutMs)
+ }
+ require(buffer.hasNext(),
+ s"Failed to get records for compacted $groupId $topicPartition " +
+ s"after polling for $pollTimeoutMs")
+ val record = buffer.next()
+ nextOffset = record.offset + 1
+ record
+ }
+
+ /**
+ * Rewind to previous record in the batch from a compacted topic.
+ * @throws NoSuchElementException if no previous element
+ */
+ def compactedPrevious(): ConsumerRecord[K, V] = {
+ buffer.previous()
+ }
+
+ private def seek(offset: Long): Unit = {
+ logDebug(s"Seeking to $topicPartition $offset")
+ consumer.seek(topicPartition, offset)
+ }
+
+ private def poll(timeout: Long): Unit = {
+ val p = consumer.poll(timeout)
+ val r = p.records(topicPartition)
+ logDebug(s"Polled ${p.partitions()} ${r.size}")
+ buffer = r.listIterator
+ }
+
+}
+
+private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition)
+
+private[kafka010] object KafkaDataConsumer extends Logging {
+
+ private case class CachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V])
+ extends KafkaDataConsumer[K, V] {
+ assert(internalConsumer.inUse)
+ override def release(): Unit = KafkaDataConsumer.release(internalConsumer)
+ }
+
+ private case class NonCachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V])
+ extends KafkaDataConsumer[K, V] {
+ override def release(): Unit = internalConsumer.close()
+ }
+
+ // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap
+ private[kafka010] var cache: ju.Map[CacheKey, InternalKafkaConsumer[_, _]] = null
+
+ /**
+ * Must be called before acquire, once per JVM, to configure the cache.
+ * Further calls are ignored.
+ */
+ def init(
+ initialCapacity: Int,
+ maxCapacity: Int,
+ loadFactor: Float): Unit = synchronized {
+ if (null == cache) {
+ logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor")
+ cache = new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer[_, _]](
+ initialCapacity, loadFactor, true) {
+ override def removeEldestEntry(
+ entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer[_, _]]): Boolean = {
+
+ // Try to remove the least-recently-used entry if it's currently not in use.
+ //
+ // If it cannot be removed, then the cache will keep growing. In the worst case,
+ // the cache will grow to the max number of concurrent tasks that can run in the executor
+ // (that is, the number of task slots), after which it will never shrink. This is unlikely to
+ // be a serious problem because an executor with more than 64 (default) task slots is
+ // likely running on a beefy machine that can handle a large number of simultaneously
+ // active consumers.
+
+ if (entry.getValue.inUse == false && this.size > maxCapacity) {
+ logWarning(
+ s"KafkaConsumer cache hitting max capacity of $maxCapacity, " +
+ s"removing consumer for ${entry.getKey}")
+ try {
+ entry.getValue.close()
+ } catch {
+ case x: KafkaException =>
+ logError("Error closing oldest Kafka consumer", x)
+ }
+ true
+ } else {
+ false
+ }
+ }
+ }
+ }
+ }
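The access-ordered LinkedHashMap built in init() is what gives the cache its LRU behaviour: the third constructor argument (true) moves an entry to the tail on every access, and removeEldestEntry is consulted on every insertion. A minimal stand-alone sketch of the same idiom, illustrative only and not part of the patch (the real eviction above additionally refuses to evict a consumer that is still in use):

    import java.{util => ju}

    // Illustrative LRU cache built the same way as the cache above: accessOrder = true
    // plus an overridden removeEldestEntry gives least-recently-used eviction.
    def lruCache[K, V](maxCapacity: Int): ju.Map[K, V] =
      new ju.LinkedHashMap[K, V](16, 0.75f, true) {
        override def removeEldestEntry(eldest: ju.Map.Entry[K, V]): Boolean =
          size() > maxCapacity
      }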
+
+ /**
+ * Get a cached consumer for groupId, assigned to topic and partition.
+ * If a matching consumer doesn't already exist, one will be created using kafkaParams.
+ * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]].
+ *
+ * Note: This method guarantees that the consumer returned is not currently in use by anyone
+ * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by
+ * caching them and tracking when they are in use.
+ */
+ def acquire[K, V](
+ topicPartition: TopicPartition,
+ kafkaParams: ju.Map[String, Object],
+ context: TaskContext,
+ useCache: Boolean): KafkaDataConsumer[K, V] = synchronized {
+ val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+ val key = new CacheKey(groupId, topicPartition)
+ val existingInternalConsumer = cache.get(key)
+
+ lazy val newInternalConsumer = new InternalKafkaConsumer[K, V](topicPartition, kafkaParams)
+
+ if (context != null && context.attemptNumber >= 1) {
+ // If this is a reattempt at running the task, then invalidate the cached consumer if any and
+ // start with a new one. If prior attempt failures were cache related, this way the old
+ // problematic consumer gets removed.
+ logDebug(s"Reattempt detected, invalidating cached consumer $existingInternalConsumer")
+ if (existingInternalConsumer != null) {
+ // Consumer exists in cache. If it's in use, mark it for closing later; otherwise close it now.
+ if (existingInternalConsumer.inUse) {
+ existingInternalConsumer.markedForClose = true
+ } else {
+ existingInternalConsumer.close()
+ // Remove the consumer from the cache only if it's closed.
+ // Consumers marked for close will be removed in the release function.
+ cache.remove(key)
+ }
+ }
+
+ logDebug("Reattempt detected, new non-cached consumer will be allocated " +
+ s"$newInternalConsumer")
+ NonCachedKafkaDataConsumer(newInternalConsumer)
+ } else if (!useCache) {
+ // If consumer reuse is turned off, do not use the cache; return a new consumer
+ logDebug("Cache usage turned off, new non-cached consumer will be allocated " +
+ s"$newInternalConsumer")
+ NonCachedKafkaDataConsumer(newInternalConsumer)
+ } else if (existingInternalConsumer == null) {
+ // If the consumer is not already cached, then put a new one in the cache and return it
+ logDebug("No cached consumer, new cached consumer will be allocated " +
+ s"$newInternalConsumer")
+ cache.put(key, newInternalConsumer)
+ CachedKafkaDataConsumer(newInternalConsumer)
+ } else if (existingInternalConsumer.inUse) {
+ // If consumer is already cached but is currently in use, then return a new consumer
+ logDebug("Used cached consumer found, new non-cached consumer will be allocated " +
+ s"$newInternalConsumer")
+ NonCachedKafkaDataConsumer(newInternalConsumer)
+ } else {
+ // If the consumer is already cached and currently not in use, then return that consumer
+ logDebug(s"Unused cached consumer found, re-using it $existingInternalConsumer")
+ existingInternalConsumer.inUse = true
+ // Any given TopicPartition should have a consistent key and value type
+ CachedKafkaDataConsumer(existingInternalConsumer.asInstanceOf[InternalKafkaConsumer[K, V]])
+ }
+ }
+
+ private def release(internalConsumer: InternalKafkaConsumer[_, _]): Unit = synchronized {
+ // Clear the consumer from the cache if this is indeed the consumer present in the cache
+ val key = new CacheKey(internalConsumer.groupId, internalConsumer.topicPartition)
+ val cachedInternalConsumer = cache.get(key)
+ if (internalConsumer.eq(cachedInternalConsumer)) {
+ // The released consumer is the same object as the cached one.
+ if (internalConsumer.markedForClose) {
+ internalConsumer.close()
+ cache.remove(key)
+ } else {
+ internalConsumer.inUse = false
+ }
+ } else {
+ // The released consumer is either not the same one as in the cache, or not in the cache
+ // at all. This may happen if the cache was invalidated while this consumer was being used.
+ // Just close this consumer.
+ internalConsumer.close()
+ logInfo(s"Released a supposedly cached consumer that was not found in the cache " +
+ s"$internalConsumer")
+ }
+ }
+}
+
+private[kafka010] object InternalKafkaConsumer {
+ private val UNKNOWN_OFFSET = -2L
+}
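Taken together, the intended calling pattern for this new API is acquire, use, release, which is exactly what the KafkaRDDIterator changes later in this diff do. A hedged sketch of that pattern from task code, assuming org.apache.spark.TaskContext is imported; topicPartition, kafkaParams, offset and process are placeholders, not names from the patch:

    // Initialize the shared cache once per JVM, then borrow a consumer for this partition.
    KafkaDataConsumer.init(initialCapacity = 16, maxCapacity = 64, loadFactor = 0.75f)
    val consumer = KafkaDataConsumer.acquire[String, String](
      topicPartition, kafkaParams, TaskContext.get(), useCache = true)
    try {
      val record = consumer.get(offset, pollTimeoutMs = 120000L)
      process(record) // placeholder for whatever the task does with the record
    } finally {
      consumer.release() // returns a cached consumer to the pool, or closes a non-cached one
    }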
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
index d9fc9cc206647..3efc90fe466b2 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
@@ -19,8 +19,6 @@ package org.apache.spark.streaming.kafka010
import java.{ util => ju }
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord }
import org.apache.kafka.common.TopicPartition
@@ -55,25 +53,27 @@ private[spark] class KafkaRDD[K, V](
useConsumerCache: Boolean
) extends RDD[ConsumerRecord[K, V]](sc, Nil) with Logging with HasOffsetRanges {
- assert("none" ==
+ require("none" ==
kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).asInstanceOf[String],
ConsumerConfig.AUTO_OFFSET_RESET_CONFIG +
" must be set to none for executor kafka params, else messages may not match offsetRange")
- assert(false ==
+ require(false ==
kafkaParams.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG).asInstanceOf[Boolean],
ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG +
" must be set to false for executor kafka params, else offsets may commit before processing")
// TODO is it necessary to have separate configs for initial poll time vs ongoing poll time?
private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms",
- conf.getTimeAsMs("spark.network.timeout", "120s"))
+ conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L)
private val cacheInitialCapacity =
conf.getInt("spark.streaming.kafka.consumer.cache.initialCapacity", 16)
private val cacheMaxCapacity =
conf.getInt("spark.streaming.kafka.consumer.cache.maxCapacity", 64)
private val cacheLoadFactor =
conf.getDouble("spark.streaming.kafka.consumer.cache.loadFactor", 0.75).toFloat
+ private val compacted =
+ conf.getBoolean("spark.streaming.kafka.allowNonConsecutiveOffsets", false)
override def persist(newLevel: StorageLevel): this.type = {
logError("Kafka ConsumerRecord is not serializable. " +
@@ -87,48 +87,63 @@ private[spark] class KafkaRDD[K, V](
}.toArray
}
- override def count(): Long = offsetRanges.map(_.count).sum
+ override def count(): Long =
+ if (compacted) {
+ super.count()
+ } else {
+ offsetRanges.map(_.count).sum
+ }
override def countApprox(
timeout: Long,
confidence: Double = 0.95
- ): PartialResult[BoundedDouble] = {
- val c = count
- new PartialResult(new BoundedDouble(c, 1.0, c, c), true)
- }
-
- override def isEmpty(): Boolean = count == 0L
-
- override def take(num: Int): Array[ConsumerRecord[K, V]] = {
- val nonEmptyPartitions = this.partitions
- .map(_.asInstanceOf[KafkaRDDPartition])
- .filter(_.count > 0)
+ ): PartialResult[BoundedDouble] =
+ if (compacted) {
+ super.countApprox(timeout, confidence)
+ } else {
+ val c = count
+ new PartialResult(new BoundedDouble(c, 1.0, c, c), true)
+ }
- if (num < 1 || nonEmptyPartitions.isEmpty) {
- return new Array[ConsumerRecord[K, V]](0)
+ override def isEmpty(): Boolean =
+ if (compacted) {
+ super.isEmpty()
+ } else {
+ count == 0L
}
- // Determine in advance how many messages need to be taken from each partition
- val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) =>
- val remain = num - result.values.sum
- if (remain > 0) {
- val taken = Math.min(remain, part.count)
- result + (part.index -> taken.toInt)
+ override def take(num: Int): Array[ConsumerRecord[K, V]] =
+ if (compacted) {
+ super.take(num)
+ } else if (num < 1) {
+ Array.empty[ConsumerRecord[K, V]]
+ } else {
+ val nonEmptyPartitions = this.partitions
+ .map(_.asInstanceOf[KafkaRDDPartition])
+ .filter(_.count > 0)
+
+ if (nonEmptyPartitions.isEmpty) {
+ Array.empty[ConsumerRecord[K, V]]
} else {
- result
+ // Determine in advance how many messages need to be taken from each partition
+ val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) =>
+ val remain = num - result.values.sum
+ if (remain > 0) {
+ val taken = Math.min(remain, part.count)
+ result + (part.index -> taken.toInt)
+ } else {
+ result
+ }
+ }
+
+ context.runJob(
+ this,
+ (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) =>
+ it.take(parts(tc.partitionId)).toArray, parts.keys.toArray
+ ).flatten
}
}
- val buf = new ArrayBuffer[ConsumerRecord[K, V]]
- val res = context.runJob(
- this,
- (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) =>
- it.take(parts(tc.partitionId)).toArray, parts.keys.toArray
- )
- res.foreach(buf ++= _)
- buf.toArray
- }
-
private def executors(): Array[ExecutorCacheTaskLocation] = {
val bm = sparkContext.env.blockManager
bm.master.getPeers(bm.blockManagerId).toArray
@@ -172,57 +187,130 @@ private[spark] class KafkaRDD[K, V](
override def compute(thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[K, V]] = {
val part = thePart.asInstanceOf[KafkaRDDPartition]
- assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part))
+ require(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part))
if (part.fromOffset == part.untilOffset) {
logInfo(s"Beginning offset ${part.fromOffset} is the same as ending offset " +
s"skipping ${part.topic} ${part.partition}")
Iterator.empty
} else {
- new KafkaRDDIterator(part, context)
+ logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " +
+ s"offsets ${part.fromOffset} -> ${part.untilOffset}")
+ if (compacted) {
+ new CompactedKafkaRDDIterator[K, V](
+ part,
+ context,
+ kafkaParams,
+ useConsumerCache,
+ pollTimeout,
+ cacheInitialCapacity,
+ cacheMaxCapacity,
+ cacheLoadFactor
+ )
+ } else {
+ new KafkaRDDIterator[K, V](
+ part,
+ context,
+ kafkaParams,
+ useConsumerCache,
+ pollTimeout,
+ cacheInitialCapacity,
+ cacheMaxCapacity,
+ cacheLoadFactor
+ )
+ }
}
}
+}
- /**
- * An iterator that fetches messages directly from Kafka for the offsets in partition.
- * Uses a cached consumer where possible to take advantage of prefetching
- */
- private class KafkaRDDIterator(
- part: KafkaRDDPartition,
- context: TaskContext) extends Iterator[ConsumerRecord[K, V]] {
-
- logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " +
- s"offsets ${part.fromOffset} -> ${part.untilOffset}")
-
- val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+/**
+ * An iterator that fetches messages directly from Kafka for the offsets in partition.
+ * Uses a cached consumer where possible to take advantage of prefetching
+ */
+private class KafkaRDDIterator[K, V](
+ part: KafkaRDDPartition,
+ context: TaskContext,
+ kafkaParams: ju.Map[String, Object],
+ useConsumerCache: Boolean,
+ pollTimeout: Long,
+ cacheInitialCapacity: Int,
+ cacheMaxCapacity: Int,
+ cacheLoadFactor: Float
+) extends Iterator[ConsumerRecord[K, V]] {
+
+ context.addTaskCompletionListener(_ => closeIfNeeded())
+
+ val consumer = {
+ KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
+ KafkaDataConsumer.acquire[K, V](part.topicPartition(), kafkaParams, context, useConsumerCache)
+ }
- context.addTaskCompletionListener{ context => closeIfNeeded() }
+ var requestOffset = part.fromOffset
- val consumer = if (useConsumerCache) {
- CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
- if (context.attemptNumber >= 1) {
- // just in case the prior attempt failures were cache related
- CachedKafkaConsumer.remove(groupId, part.topic, part.partition)
- }
- CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams)
- } else {
- CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams)
+ def closeIfNeeded(): Unit = {
+ if (consumer != null) {
+ consumer.release()
}
+ }
- var requestOffset = part.fromOffset
+ override def hasNext(): Boolean = requestOffset < part.untilOffset
- def closeIfNeeded(): Unit = {
- if (!useConsumerCache && consumer != null) {
- consumer.close
- }
+ override def next(): ConsumerRecord[K, V] = {
+ if (!hasNext) {
+ throw new ju.NoSuchElementException("Can't call next() once untilOffset has been reached")
}
+ val r = consumer.get(requestOffset, pollTimeout)
+ requestOffset += 1
+ r
+ }
+}
- override def hasNext(): Boolean = requestOffset < part.untilOffset
-
- override def next(): ConsumerRecord[K, V] = {
- assert(hasNext(), "Can't call getNext() once untilOffset has been reached")
- val r = consumer.get(requestOffset, pollTimeout)
- requestOffset += 1
- r
+/**
+ * An iterator that fetches messages directly from Kafka for the offsets in partition.
+ * Uses a cached consumer where possible to take advantage of prefetching.
+ * Intended for compacted topics, or other cases when non-consecutive offsets are ok.
+ */
+private class CompactedKafkaRDDIterator[K, V](
+ part: KafkaRDDPartition,
+ context: TaskContext,
+ kafkaParams: ju.Map[String, Object],
+ useConsumerCache: Boolean,
+ pollTimeout: Long,
+ cacheInitialCapacity: Int,
+ cacheMaxCapacity: Int,
+ cacheLoadFactor: Float
+ ) extends KafkaRDDIterator[K, V](
+ part,
+ context,
+ kafkaParams,
+ useConsumerCache,
+ pollTimeout,
+ cacheInitialCapacity,
+ cacheMaxCapacity,
+ cacheLoadFactor
+ ) {
+
+ consumer.compactedStart(part.fromOffset, pollTimeout)
+
+ private var nextRecord = consumer.compactedNext(pollTimeout)
+
+ private var okNext: Boolean = true
+
+ override def hasNext(): Boolean = okNext
+
+ override def next(): ConsumerRecord[K, V] = {
+ if (!hasNext) {
+ throw new ju.NoSuchElementException("Can't call next() once untilOffset has been reached")
+ }
+ val r = nextRecord
+ if (r.offset + 1 >= part.untilOffset) {
+ okNext = false
+ } else {
+ nextRecord = consumer.compactedNext(pollTimeout)
+ if (nextRecord.offset >= part.untilOffset) {
+ okNext = false
+ consumer.compactedPrevious()
+ }
}
+ r
}
}
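Because compaction leaves gaps in the offset sequence, CompactedKafkaRDDIterator above cannot simply count records; it reads one record ahead and rewinds with compactedPrevious() when the look-ahead crosses untilOffset. Pulled out of the class, the control flow looks roughly like this (a sketch only; consumer, fromOffset, untilOffset, pollTimeoutMs and handle are assumed to be in scope):

    consumer.compactedStart(fromOffset, pollTimeoutMs)
    var next = consumer.compactedNext(pollTimeoutMs)
    var more = true
    while (more) {
      val current = next
      handle(current) // placeholder per-record callback
      if (current.offset + 1 >= untilOffset) {
        more = false // current carried the last offset in the requested range
      } else {
        next = consumer.compactedNext(pollTimeoutMs)
        if (next.offset >= untilOffset) {
          more = false
          consumer.compactedPrevious() // rewind so the overshoot record stays buffered for reuse
        }
      }
    }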
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
index 453b5e5ab20d3..35e4678f2e3c8 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka010
import java.io.File
import java.lang.{ Long => JLong }
-import java.util.{ Arrays, HashMap => JHashMap, Map => JMap }
+import java.util.{ Arrays, HashMap => JHashMap, Map => JMap, UUID }
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicLong
@@ -34,7 +34,7 @@ import org.apache.kafka.common.serialization.StringDeserializer
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.Eventually
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
@@ -617,6 +617,101 @@ class DirectKafkaStreamSuite
ssc.stop()
}
+ test("backpressure.initialRate should honor maxRatePerPartition") {
+ backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250)
+ }
+
+ test("use backpressure.initialRate with backpressure") {
+ backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150)
+ }
+
+ private def backpressureTest(
+ maxRatePerPartition: Int,
+ initialRate: Int,
+ maxMessagesPerPartition: Int) = {
+
+ val topic = UUID.randomUUID().toString
+ val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+ val sparkConf = new SparkConf()
+ // Safe, even with streaming, because we're using the direct API.
+ // Using 1 core is useful to make the test more predictable.
+ .setMaster("local[1]")
+ .setAppName(this.getClass.getSimpleName)
+ .set("spark.streaming.backpressure.enabled", "true")
+ .set("spark.streaming.backpressure.initialRate", initialRate.toString)
+ .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString)
+
+ val messages = Map("foo" -> 5000)
+ kafkaTestUtils.sendMessages(topic, messages)
+
+ ssc = new StreamingContext(sparkConf, Milliseconds(500))
+
+ val kafkaStream = withClue("Error creating direct stream") {
+ new DirectKafkaInputDStream[String, String](
+ ssc,
+ preferredHosts,
+ ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala),
+ new DefaultPerPartitionConfig(sparkConf)
+ )
+ }
+ kafkaStream.start()
+
+ val input = Map(new TopicPartition(topic, 0) -> 1000L)
+
+ assert(kafkaStream.maxMessagesPerPartition(input).get ==
+ Map(new TopicPartition(topic, 0) -> maxMessagesPerPartition)) // we run for half a second
+
+ kafkaStream.stop()
+ }
+
+ test("maxMessagesPerPartition with zero offset and rate equal to one") {
+ val topic = "backpressure"
+ val kafkaParams = getKafkaParams()
+ val batchIntervalMilliseconds = 60000
+ val sparkConf = new SparkConf()
+ // Safe, even with streaming, because we're using the direct API.
+ // Using 1 core is useful to make the test more predictable.
+ .setMaster("local[1]")
+ .setAppName(this.getClass.getSimpleName)
+ .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+ // Setup the streaming context
+ ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+ val estimateRate = 1L
+ val fromOffsets = Map(
+ new TopicPartition(topic, 0) -> 0L,
+ new TopicPartition(topic, 1) -> 0L,
+ new TopicPartition(topic, 2) -> 0L,
+ new TopicPartition(topic, 3) -> 0L
+ )
+ val kafkaStream = withClue("Error creating direct stream") {
+ new DirectKafkaInputDStream[String, String](
+ ssc,
+ preferredHosts,
+ ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala),
+ new DefaultPerPartitionConfig(sparkConf)
+ ) {
+ currentOffsets = fromOffsets
+ override val rateController = Some(new ConstantRateController(id, null, estimateRate))
+ }
+ }
+
+ val offsets = Map[TopicPartition, Long](
+ new TopicPartition(topic, 0) -> 0,
+ new TopicPartition(topic, 1) -> 100L,
+ new TopicPartition(topic, 2) -> 200L,
+ new TopicPartition(topic, 3) -> 300L
+ )
+ val result = kafkaStream.maxMessagesPerPartition(offsets)
+ val expected = Map(
+ new TopicPartition(topic, 0) -> 1L,
+ new TopicPartition(topic, 1) -> 10L,
+ new TopicPartition(topic, 2) -> 20L,
+ new TopicPartition(topic, 3) -> 30L
+ )
+ assert(result.contains(expected), s"Number of messages per partition must be at least 1")
+ }
+
/** Get the generated offset ranges from the DirectKafkaStream */
private def getOffsetRanges[K, V](
kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = {
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
new file mode 100644
index 0000000000000..d934c64962adb
--- /dev/null
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.kafka.clients.consumer.ConsumerConfig._
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.serialization.ByteArrayDeserializer
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark._
+
+class KafkaDataConsumerSuite extends SparkFunSuite with BeforeAndAfterAll {
+ private var testUtils: KafkaTestUtils = _
+ private val topic = "topic" + Random.nextInt()
+ private val topicPartition = new TopicPartition(topic, 0)
+ private val groupId = "groupId"
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ testUtils = new KafkaTestUtils
+ testUtils.setup()
+ KafkaDataConsumer.init(16, 64, 0.75f)
+ }
+
+ override def afterAll(): Unit = {
+ if (testUtils != null) {
+ testUtils.teardown()
+ testUtils = null
+ }
+ super.afterAll()
+ }
+
+ private def getKafkaParams() = Map[String, Object](
+ GROUP_ID_CONFIG -> groupId,
+ BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
+ KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+ VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+ AUTO_OFFSET_RESET_CONFIG -> "earliest",
+ ENABLE_AUTO_COMMIT_CONFIG -> "false"
+ ).asJava
+
+ test("KafkaDataConsumer reuse in case of same groupId and TopicPartition") {
+ KafkaDataConsumer.cache.clear()
+
+ val kafkaParams = getKafkaParams()
+
+ val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+ topicPartition, kafkaParams, null, true)
+ consumer1.release()
+
+ val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+ topicPartition, kafkaParams, null, true)
+ consumer2.release()
+
+ assert(KafkaDataConsumer.cache.size() == 1)
+ val key = new CacheKey(groupId, topicPartition)
+ val existingInternalConsumer = KafkaDataConsumer.cache.get(key)
+ assert(existingInternalConsumer.eq(consumer1.internalConsumer))
+ assert(existingInternalConsumer.eq(consumer2.internalConsumer))
+ }
+
+ test("concurrent use of KafkaDataConsumer") {
+ val data = (1 to 1000).map(_.toString)
+ testUtils.createTopic(topic)
+ testUtils.sendMessages(topic, data.toArray)
+
+ val kafkaParams = getKafkaParams()
+
+ val numThreads = 100
+ val numConsumerUsages = 500
+
+ @volatile var error: Throwable = null
+
+ def consume(i: Int): Unit = {
+ val useCache = Random.nextBoolean
+ val taskContext = if (Random.nextBoolean) {
+ new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
+ } else {
+ null
+ }
+ val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+ topicPartition, kafkaParams, taskContext, useCache)
+ try {
+ val rcvd = (0 until data.length).map { offset =>
+ val bytes = consumer.get(offset, 10000).value()
+ new String(bytes)
+ }
+ assert(rcvd == data)
+ } catch {
+ case e: Throwable =>
+ error = e
+ throw e
+ } finally {
+ consumer.release()
+ }
+ }
+
+ val threadPool = Executors.newFixedThreadPool(numThreads)
+ try {
+ val futures = (1 to numConsumerUsages).map { i =>
+ threadPool.submit(new Runnable {
+ override def run(): Unit = { consume(i) }
+ })
+ }
+ futures.foreach(_.get(1, TimeUnit.MINUTES))
+ assert(error == null)
+ } finally {
+ threadPool.shutdown()
+ }
+ }
+}
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
index be373af0599cc..271adea1df731 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
@@ -18,16 +18,22 @@
package org.apache.spark.streaming.kafka010
import java.{ util => ju }
+import java.io.File
import scala.collection.JavaConverters._
import scala.util.Random
+import kafka.common.TopicAndPartition
+import kafka.log._
+import kafka.message._
+import kafka.utils.Pool
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.StringDeserializer
import org.scalatest.BeforeAndAfterAll
import org.apache.spark._
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.streaming.kafka010.mocks.MockTime
class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
@@ -64,6 +70,41 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
private val preferredHosts = LocationStrategies.PreferConsistent
+ private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) {
+ val mockTime = new MockTime()
+ // The LogCleaner in the 0.10 version of Kafka still expects the old TopicAndPartition API
+ val logs = new Pool[TopicAndPartition, Log]()
+ val logDir = kafkaTestUtils.brokerLogDir
+ val dir = new File(logDir, topic + "-" + partition)
+ dir.mkdirs()
+ val logProps = new ju.Properties()
+ logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact)
+ logProps.put(LogConfig.MinCleanableDirtyRatioProp, java.lang.Float.valueOf(0.1f))
+ val log = new Log(
+ dir,
+ LogConfig(logProps),
+ 0L,
+ mockTime.scheduler,
+ mockTime
+ )
+ messages.foreach { case (k, v) =>
+ val msg = new ByteBufferMessageSet(
+ NoCompressionCodec,
+ new Message(v.getBytes, k.getBytes, Message.NoTimestamp, Message.CurrentMagicValue))
+ log.append(msg)
+ }
+ log.roll()
+ logs.put(TopicAndPartition(topic, partition), log)
+
+ val cleaner = new LogCleaner(CleanerConfig(), logDirs = Array(dir), logs = logs)
+ cleaner.startup()
+ cleaner.awaitCleaned(topic, partition, log.activeSegment.baseOffset, 1000)
+
+ cleaner.shutdown()
+ mockTime.scheduler.shutdown()
+ }
+
+
test("basic usage") {
val topic = s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}"
kafkaTestUtils.createTopic(topic)
@@ -102,6 +143,71 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}
+ test("compacted topic") {
+ val compactConf = sparkConf.clone()
+ compactConf.set("spark.streaming.kafka.allowNonConsecutiveOffsets", "true")
+ sc.stop()
+ sc = new SparkContext(compactConf)
+ val topic = s"topiccompacted-${Random.nextInt}-${System.currentTimeMillis}"
+
+ val messages = Array(
+ ("a", "1"),
+ ("a", "2"),
+ ("b", "1"),
+ ("c", "1"),
+ ("c", "2"),
+ ("b", "2"),
+ ("b", "3")
+ )
+ val compactedMessages = Array(
+ ("a", "2"),
+ ("b", "3"),
+ ("c", "2")
+ )
+
+ compactLogs(topic, 0, messages)
+
+ val props = new ju.Properties()
+ props.put("cleanup.policy", "compact")
+ props.put("flush.messages", "1")
+ props.put("segment.ms", "1")
+ props.put("segment.bytes", "256")
+ kafkaTestUtils.createTopic(topic, 1, props)
+
+
+ val kafkaParams = getKafkaParams()
+
+ val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))
+
+ val rdd = KafkaUtils.createRDD[String, String](
+ sc, kafkaParams, offsetRanges, preferredHosts
+ ).map(m => m.key -> m.value)
+
+ val received = rdd.collect.toSet
+ assert(received === compactedMessages.toSet)
+
+ // size-related method optimizations return sane results
+ assert(rdd.count === compactedMessages.size)
+ assert(rdd.countApprox(0).getFinalValue.mean === compactedMessages.size)
+ assert(!rdd.isEmpty)
+ assert(rdd.take(1).size === 1)
+ assert(rdd.take(1).head === compactedMessages.head)
+ assert(rdd.take(messages.size + 10).size === compactedMessages.size)
+
+ val emptyRdd = KafkaUtils.createRDD[String, String](
+ sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)), preferredHosts)
+
+ assert(emptyRdd.isEmpty)
+
+ // invalid offset ranges throw exceptions
+ val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1))
+ intercept[SparkException] {
+ val result = KafkaUtils.createRDD[String, String](sc, kafkaParams, badRanges, preferredHosts)
+ .map(_.value)
+ .collect()
+ }
+ }
+
test("iterator boundary conditions") {
// the idea is to find e.g. off-by-one errors between what kafka has available and the rdd
val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}"
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala
index 6c7024ea4b5a5..70b579d96d692 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala
@@ -162,17 +162,22 @@ private[kafka010] class KafkaTestUtils extends Logging {
}
/** Create a Kafka topic and wait until it is propagated to the whole cluster */
- def createTopic(topic: String, partitions: Int): Unit = {
- AdminUtils.createTopic(zkUtils, topic, partitions, 1)
+ def createTopic(topic: String, partitions: Int, config: Properties): Unit = {
+ AdminUtils.createTopic(zkUtils, topic, partitions, 1, config)
// wait until metadata is propagated
(0 until partitions).foreach { p =>
waitUntilMetadataIsPropagated(topic, p)
}
}
+ /** Create a Kafka topic and wait until it is propagated to the whole cluster */
+ def createTopic(topic: String, partitions: Int): Unit = {
+ createTopic(topic, partitions, new Properties())
+ }
+
/** Create a Kafka topic and wait until it is propagated to the whole cluster */
def createTopic(topic: String): Unit = {
- createTopic(topic, 1)
+ createTopic(topic, 1, new Properties())
}
/** Java-friendly function for sending messages to the Kafka broker */
@@ -196,12 +201,24 @@ private[kafka010] class KafkaTestUtils extends Logging {
producer = null
}
+ /** Send the array of (key, value) messages to the Kafka broker */
+ def sendMessages(topic: String, messages: Array[(String, String)]): Unit = {
+ producer = new KafkaProducer[String, String](producerConfiguration)
+ messages.foreach { message =>
+ producer.send(new ProducerRecord[String, String](topic, message._1, message._2))
+ }
+ producer.close()
+ producer = null
+ }
+
+ val brokerLogDir = Utils.createTempDir().getAbsolutePath
+
private def brokerConfiguration: Properties = {
val props = new Properties()
props.put("broker.id", "0")
props.put("host.name", "localhost")
props.put("port", brokerPort.toString)
- props.put("log.dir", Utils.createTempDir().getAbsolutePath)
+ props.put("log.dir", brokerLogDir)
props.put("zookeeper.connect", zkAddress)
props.put("log.flush.interval.messages", "1")
props.put("replica.socket.timeout.ms", "1500")
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala
new file mode 100644
index 0000000000000..928e1a6ef54b9
--- /dev/null
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010.mocks
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable.PriorityQueue
+
+import kafka.utils.{Scheduler, Time}
+
+/**
+ * A mock scheduler that executes tasks synchronously using a mock time instance.
+ * Tasks are executed synchronously when the time is advanced.
+ * This class is meant to be used in conjunction with MockTime.
+ *
+ * Example usage
+ *
+ * val time = new MockTime
+ * time.scheduler.schedule("a task", println("hello world: " + time.milliseconds), delay = 1000)
+ * time.sleep(1001) // this should cause our scheduled task to fire
+ *
+ *
+ * Incrementing the time to the exact next execution time of a task will result in that task
+ * executing (it is as if execution itself takes no time).
+ */
+private[kafka010] class MockScheduler(val time: Time) extends Scheduler {
+
+ /* a priority queue of tasks ordered by next execution time */
+ var tasks = new PriorityQueue[MockTask]()
+
+ def isStarted: Boolean = true
+
+ def startup(): Unit = {}
+
+ def shutdown(): Unit = synchronized {
+ tasks.foreach(_.fun())
+ tasks.clear()
+ }
+
+ /**
+ * Check for any tasks that need to execute. Since this is a mock scheduler this check only occurs
+ * when this method is called and the execution happens synchronously in the calling thread.
+ * If you are using the scheduler associated with a MockTime instance this call
+ * will be triggered automatically.
+ */
+ def tick(): Unit = synchronized {
+ val now = time.milliseconds
+ while(!tasks.isEmpty && tasks.head.nextExecution <= now) {
+ /* pop and execute the task with the lowest next execution time */
+ val curr = tasks.dequeue
+ curr.fun()
+ /* if the task is periodic, reschedule it and re-enqueue */
+ if(curr.periodic) {
+ curr.nextExecution += curr.period
+ this.tasks += curr
+ }
+ }
+ }
+
+ def schedule(
+ name: String,
+ fun: () => Unit,
+ delay: Long = 0,
+ period: Long = -1,
+ unit: TimeUnit = TimeUnit.MILLISECONDS): Unit = synchronized {
+ tasks += MockTask(name, fun, time.milliseconds + delay, period = period)
+ tick()
+ }
+
+}
+
+case class MockTask(
+ val name: String,
+ val fun: () => Unit,
+ var nextExecution: Long,
+ val period: Long) extends Ordered[MockTask] {
+ def periodic: Boolean = period >= 0
+ def compare(t: MockTask): Int = {
+ java.lang.Long.compare(t.nextExecution, nextExecution)
+ }
+}
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala
new file mode 100644
index 0000000000000..a68f94db1f689
--- /dev/null
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010.mocks
+
+import java.util.concurrent._
+
+import kafka.utils.Time
+
+/**
+ * A class used for unit testing things which depend on the Time interface.
+ *
+ * This class never advances the clock on its own; the clock only moves when you call
+ * sleep(ms).
+ *
+ * It also comes with an associated scheduler instance for managing background tasks in
+ * a deterministic way.
+ */
+private[kafka010] class MockTime(@volatile private var currentMs: Long) extends Time {
+
+ val scheduler = new MockScheduler(this)
+
+ def this() = this(System.currentTimeMillis)
+
+ def milliseconds: Long = currentMs
+
+ def nanoseconds: Long =
+ TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS)
+
+ def sleep(ms: Long) {
+ this.currentMs += ms
+ scheduler.tick()
+ }
+
+ override def toString(): String = s"MockTime($milliseconds)"
+
+}
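The scaladoc example in MockScheduler is schematic and would not compile as written (schedule takes a () => Unit). A compiling version of the same idea, assuming the caller lives in the org.apache.spark.streaming.kafka010 package so the private[kafka010] mocks are visible:

    val time = new MockTime()
    time.scheduler.schedule(
      "a task", () => println("hello world: " + time.milliseconds), delay = 1000)
    time.sleep(1001) // advances the mock clock and ticks the scheduler, firing the task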
diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index d52c230eb7849..9297c39d170c4 100644
--- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -91,9 +91,16 @@ class DirectKafkaInputDStream[
private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong(
"spark.streaming.kafka.maxRatePerPartition", 0)
+ private val initialRate = context.sparkContext.getConf.getLong(
+ "spark.streaming.backpressure.initialRate", 0)
+
protected[streaming] def maxMessagesPerPartition(
offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = {
- val estimatedRateLimit = rateController.map(_.getLatestRate())
+
+ val estimatedRateLimit = rateController.map { x => {
+ val lr = x.getLatestRate()
+ if (lr > 0) lr else initialRate
+ }}
// calculate a per-partition rate limit based on current lag
val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {
@@ -104,17 +111,17 @@ class DirectKafkaInputDStream[
val totalLag = lagPerPartition.values.sum
lagPerPartition.map { case (tp, lag) =>
- val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
+ val backpressureRate = lag / totalLag.toDouble * rate
tp -> (if (maxRateLimitPerPartition > 0) {
Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
}
- case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition }
+ case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition.toDouble }
}
if (effectiveRateLimitPerPartition.values.sum > 0) {
val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
Some(effectiveRateLimitPerPartition.map {
- case (tp, limit) => tp -> (secsPerBatch * limit).toLong
+ case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L)
})
} else {
None
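The switch to a Double backpressure rate plus the new Math.max(..., 1L) floor is what the "maxMessagesPerPartition with zero offset and rate equal to one" tests in this diff assert. A hedged, self-contained re-computation with the values those tests use (60 s batch, estimated rate 1, maxRatePerPartition 100, partition lags 0/100/200/300); all names here are illustrative:

    val secsPerBatch = 60.0                  // 60000 ms batch interval
    val rate = 1.0                           // estimated rate from the rate controller
    val maxRatePerPartition = 100.0
    val lags = Map(0 -> 0L, 1 -> 100L, 2 -> 200L, 3 -> 300L)
    val totalLag = lags.values.sum.toDouble  // 600
    val limits = lags.map { case (p, lag) =>
      p -> math.min(lag / totalLag * rate, maxRatePerPartition) // proportional share, capped
    }
    val messages = limits.map { case (p, limit) =>
      p -> math.max((secsPerBatch * limit).toLong, 1L)          // at least one message
    }
    // messages == Map(0 -> 1L, 1 -> 10L, 2 -> 20L, 3 -> 30L), matching the test expectations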
diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
index 5ea52b6ad36a0..791cf0efaf888 100644
--- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
+++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
@@ -191,6 +191,7 @@ class KafkaRDD[
private def fetchBatch: Iterator[MessageAndOffset] = {
val req = new FetchRequestBuilder()
+ .clientId(consumer.clientId)
.addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes)
.build()
val resp = consumer.fetch(req)
diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index 06ef5bc3f8bd0..ecca38784e777 100644
--- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.streaming.kafka
import java.io.File
-import java.util.Arrays
+import java.util.{ Arrays, UUID }
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicLong
@@ -32,12 +32,11 @@ import kafka.serializer.StringDecoder
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.Eventually
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
import org.apache.spark.streaming.dstream.DStream
-import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
import org.apache.spark.streaming.scheduler._
import org.apache.spark.streaming.scheduler.rate.RateEstimator
import org.apache.spark.util.Utils
@@ -456,6 +455,111 @@ class DirectKafkaStreamSuite
ssc.stop()
}
+ test("use backpressure.initialRate with backpressure") {
+ backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250)
+ }
+
+ test("backpressure.initialRate should honor maxRatePerPartition") {
+ backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150)
+ }
+
+ private def backpressureTest(
+ maxRatePerPartition: Int,
+ initialRate: Int,
+ maxMessagesPerPartition: Int) = {
+
+ val topic = UUID.randomUUID().toString
+ val topicPartitions = Set(TopicAndPartition(topic, 0))
+ kafkaTestUtils.createTopic(topic, 1)
+ val kafkaParams = Map(
+ "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
+ "auto.offset.reset" -> "smallest"
+ )
+
+ val sparkConf = new SparkConf()
+ // Safe, even with streaming, because we're using the direct API.
+ // Using 1 core is useful to make the test more predictable.
+ .setMaster("local[1]")
+ .setAppName(this.getClass.getSimpleName)
+ .set("spark.streaming.backpressure.enabled", "true")
+ .set("spark.streaming.backpressure.initialRate", initialRate.toString)
+ .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString)
+
+ val messages = Map("foo" -> 5000)
+ kafkaTestUtils.sendMessages(topic, messages)
+
+ ssc = new StreamingContext(sparkConf, Milliseconds(500))
+
+ val kafkaStream = withClue("Error creating direct stream") {
+ val kc = new KafkaCluster(kafkaParams)
+ val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
+ val m = kc.getEarliestLeaderOffsets(topicPartitions)
+ .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset))
+
+ new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
+ ssc, kafkaParams, m, messageHandler)
+ }
+ kafkaStream.start()
+
+ val input = Map(new TopicAndPartition(topic, 0) -> 1000L)
+
+ assert(kafkaStream.maxMessagesPerPartition(input).get ==
+ Map(new TopicAndPartition(topic, 0) -> maxMessagesPerPartition))
+
+ kafkaStream.stop()
+ }
+
+ test("maxMessagesPerPartition with zero offset and rate equal to one") {
+ val topic = "backpressure"
+ val kafkaParams = Map(
+ "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
+ "auto.offset.reset" -> "smallest"
+ )
+
+ val batchIntervalMilliseconds = 60000
+ val sparkConf = new SparkConf()
+ // Safe, even with streaming, because we're using the direct API.
+ // Using 1 core is useful to make the test more predictable.
+ .setMaster("local[1]")
+ .setAppName(this.getClass.getSimpleName)
+ .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+ // Setup the streaming context
+ ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+ val estimatedRate = 1L
+ val kafkaStream = withClue("Error creating direct stream") {
+ val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
+ val fromOffsets = Map(
+ TopicAndPartition(topic, 0) -> 0L,
+ TopicAndPartition(topic, 1) -> 0L,
+ TopicAndPartition(topic, 2) -> 0L,
+ TopicAndPartition(topic, 3) -> 0L
+ )
+ new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
+ ssc, kafkaParams, fromOffsets, messageHandler) {
+ override protected[streaming] val rateController =
+ Some(new DirectKafkaRateController(id, null) {
+ override def getLatestRate() = estimatedRate
+ })
+ }
+ }
+
+ val offsets = Map(
+ TopicAndPartition(topic, 0) -> 0L,
+ TopicAndPartition(topic, 1) -> 100L,
+ TopicAndPartition(topic, 2) -> 200L,
+ TopicAndPartition(topic, 3) -> 300L
+ )
+ val result = kafkaStream.maxMessagesPerPartition(offsets)
+ val expected = Map(
+ TopicAndPartition(topic, 0) -> 1L,
+ TopicAndPartition(topic, 1) -> 10L,
+ TopicAndPartition(topic, 2) -> 20L,
+ TopicAndPartition(topic, 3) -> 30L
+ )
+ assert(result.contains(expected), s"Number of messages per partition must be at least 1")
+ }
+
/** Get the generated offset ranges from the DirectKafkaStream */
private def getOffsetRanges[K, V](
kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = {
diff --git a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py
index 4d7fc9a549bfb..49794faab88c4 100644
--- a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py
+++ b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py
@@ -34,7 +34,7 @@
$ export AWS_SECRET_KEY=
# run the example
- $ bin/spark-submit -jar external/kinesis-asl/target/scala-*/\
+ $ bin/spark-submit --jars external/kinesis-asl/target/scala-*/\
spark-streaming-kinesis-asl-assembly_*.jar \
external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \
myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com
diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml
index 8e424b1c50236..2c39a7df0146e 100644
--- a/hadoop-cloud/pom.xml
+++ b/hadoop-cloud/pom.xml
@@ -38,7 +38,32 @@
hadoop-cloud
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+  </build>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-client</artifactId>
+      <version>${hadoop.version}</version>
+      <scope>provided</scope>
+    </dependency>
+  </dependencies>
+
+  <profiles>
+    <profile>
+      <id>hadoop-3.1</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-cloud-storage</artifactId>
+          <version>${hadoop.version}</version>
+          <scope>${hadoop.deps.scope}</scope>
+          <exclusions>
+            <exclusion>
+              <groupId>org.apache.hadoop</groupId>
+              <artifactId>hadoop-common</artifactId>
+            </exclusion>
+            <exclusion>
+              <groupId>org.codehaus.jackson</groupId>
+              <artifactId>jackson-mapper-asl</artifactId>
+            </exclusion>
+            <exclusion>
+              <groupId>com.fasterxml.jackson.core</groupId>
+              <artifactId>jackson-core</artifactId>
+            </exclusion>
+            <exclusion>
+              <groupId>com.google.guava</groupId>
+              <artifactId>guava</artifactId>
+            </exclusion>
+          </exclusions>
+        </dependency>
+        <dependency>
+          <groupId>org.eclipse.jetty</groupId>
+          <artifactId>jetty-util</artifactId>
+          <scope>${hadoop.deps.scope}</scope>
+        </dependency>
+        <dependency>
+          <groupId>org.eclipse.jetty</groupId>
+          <artifactId>jetty-util-ajax</artifactId>
+          <version>${jetty.version}</version>
+          <scope>${hadoop.deps.scope}</scope>
+        </dependency>
+      </dependencies>
+    </profile>
+  </profiles>
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java
index 44e69fc45dffa..4e02843480e8f 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java
@@ -139,7 +139,7 @@ public T setMainClass(String mainClass) {
public T addSparkArg(String arg) {
SparkSubmitOptionParser validator = new ArgumentValidator(false);
validator.parse(Arrays.asList(arg));
- builder.sparkArgs.add(arg);
+ builder.userArgs.add(arg);
return self();
}
@@ -187,8 +187,8 @@ public T addSparkArg(String name, String value) {
}
} else {
validator.parse(Arrays.asList(name, value));
- builder.sparkArgs.add(name);
- builder.sparkArgs.add(value);
+ builder.userArgs.add(name);
+ builder.userArgs.add(value);
}
return self();
}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java
index 4b740d3fad20e..15fbca0facef2 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java
@@ -25,7 +25,7 @@
class InProcessAppHandle extends AbstractAppHandle {
private static final String THREAD_NAME_FMT = "spark-app-%d: '%s'";
- private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName());
+ private static final Logger LOG = Logger.getLogger(InProcessAppHandle.class.getName());
private static final AtomicLong THREAD_IDS = new AtomicLong();
// Avoid really long thread names.
diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java
index 6d726b4a69a86..688e1f763c205 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java
@@ -89,10 +89,18 @@ Method findSparkSubmit() throws IOException {
}
Class<?> sparkSubmit;
+ // SPARK-22941: first try the new SparkSubmit interface that has better error handling,
+ // but fall back to the old interface in case someone is mixing & matching launcher and
+ // Spark versions.
try {
- sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit");
- } catch (Exception e) {
- throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", e);
+ sparkSubmit = cl.loadClass("org.apache.spark.deploy.InProcessSparkSubmit");
+ } catch (Exception e1) {
+ try {
+ sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit");
+ } catch (Exception e2) {
+ throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.",
+ e2);
+ }
}
Method main;
diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java
index 1e34bb8c73279..d967aa39a4827 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/Main.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java
@@ -17,6 +17,7 @@
package org.apache.spark.launcher;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -54,10 +55,12 @@ public static void main(String[] argsArray) throws Exception {
String className = args.remove(0);
boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND"));
- AbstractCommandBuilder builder;
+ Map<String, String> env = new HashMap<>();
+ List<String> cmd;
if (className.equals("org.apache.spark.deploy.SparkSubmit")) {
try {
- builder = new SparkSubmitCommandBuilder(args);
+ AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(args);
+ cmd = buildCommand(builder, env, printLaunchCommand);
} catch (IllegalArgumentException e) {
printLaunchCommand = false;
System.err.println("Error: " + e.getMessage());
@@ -76,17 +79,12 @@ public static void main(String[] argsArray) throws Exception {
help.add(parser.className);
}
help.add(parser.USAGE_ERROR);
- builder = new SparkSubmitCommandBuilder(help);
+ AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(help);
+ cmd = buildCommand(builder, env, printLaunchCommand);
}
} else {
- builder = new SparkClassCommandBuilder(className, args);
- }
-
- Map<String, String> env = new HashMap<>();
- List<String> cmd = builder.buildCommand(env);
- if (printLaunchCommand) {
- System.err.println("Spark Command: " + join(" ", cmd));
- System.err.println("========================================");
+ AbstractCommandBuilder builder = new SparkClassCommandBuilder(className, args);
+ cmd = buildCommand(builder, env, printLaunchCommand);
}
if (isWindows()) {
@@ -101,6 +99,22 @@ public static void main(String[] argsArray) throws Exception {
}
}
+ /**
+ * Prepare Spark commands with the appropriate command builder.
+ * If printLaunchCommand is set, the command will be printed to stderr.
+ */
+ private static List<String> buildCommand(
+ AbstractCommandBuilder builder,
+ Map<String, String> env,
+ boolean printLaunchCommand) throws IOException, IllegalArgumentException {
+ List<String> cmd = builder.buildCommand(env);
+ if (printLaunchCommand) {
+ System.err.println("Spark Command: " + join(" ", cmd));
+ System.err.println("========================================");
+ }
+ return cmd;
+ }
+
/**
* Prepare a command line for execution from a Windows batch script.
*
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index e0ef22d7d5058..cc65f78b45c30 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -88,8 +88,10 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
SparkLauncher.NO_RESOURCE);
}
- final List<String> sparkArgs;
- private final boolean isAppResourceReq;
+ final List<String> userArgs;
+ private final List<String> parsedArgs;
+ // A special command means no appResource and no mainClass are required
+ private final boolean isSpecialCommand;
private final boolean isExample;
/**
@@ -99,17 +101,27 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
*/
private boolean allowsMixedArguments;
+ /**
+ * This constructor is used when creating a user-configurable launcher. It allows the
+ * spark-submit argument list to be modified after creation.
+ */
SparkSubmitCommandBuilder() {
- this.sparkArgs = new ArrayList<>();
- this.isAppResourceReq = true;
+ this.isSpecialCommand = false;
this.isExample = false;
+ this.parsedArgs = new ArrayList<>();
+ this.userArgs = new ArrayList<>();
}
+ /**
+ * This constructor is used when invoking spark-submit; it parses and validates arguments
+ * provided by the user on the command line.
+ */
SparkSubmitCommandBuilder(List<String> args) {
this.allowsMixedArguments = false;
- this.sparkArgs = new ArrayList<>();
+ this.parsedArgs = new ArrayList<>();
boolean isExample = false;
List<String> submitArgs = args;
+ this.userArgs = Collections.emptyList();
if (args.size() > 0) {
switch (args.get(0)) {
@@ -127,25 +139,26 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
case RUN_EXAMPLE:
isExample = true;
+ appResource = SparkLauncher.NO_RESOURCE;
submitArgs = args.subList(1, args.size());
}
this.isExample = isExample;
- OptionParser parser = new OptionParser();
+ OptionParser parser = new OptionParser(true);
parser.parse(submitArgs);
- this.isAppResourceReq = parser.isAppResourceReq;
- } else {
+ this.isSpecialCommand = parser.isSpecialCommand;
+ } else {
this.isExample = isExample;
- this.isAppResourceReq = false;
+ this.isSpecialCommand = true;
}
}
@Override
public List<String> buildCommand(Map<String, String> env)
throws IOException, IllegalArgumentException {
- if (PYSPARK_SHELL.equals(appResource) && isAppResourceReq) {
+ if (PYSPARK_SHELL.equals(appResource) && !isSpecialCommand) {
return buildPySparkShellCommand(env);
- } else if (SPARKR_SHELL.equals(appResource) && isAppResourceReq) {
+ } else if (SPARKR_SHELL.equals(appResource) && !isSpecialCommand) {
return buildSparkRCommand(env);
} else {
return buildSparkSubmitCommand(env);
@@ -154,9 +167,19 @@ public List<String> buildCommand(Map<String, String> env)
List<String> buildSparkSubmitArgs() {
List<String> args = new ArrayList<>();
- SparkSubmitOptionParser parser = new SparkSubmitOptionParser();
+ OptionParser parser = new OptionParser(false);
+ final boolean isSpecialCommand;
+
+ // If the user args array is not empty, we need to parse it to detect exactly what
+ // the user is trying to run, so that the checks below are correct.
+ if (!userArgs.isEmpty()) {
+ parser.parse(userArgs);
+ isSpecialCommand = parser.isSpecialCommand;
+ } else {
+ isSpecialCommand = this.isSpecialCommand;
+ }
- if (!allowsMixedArguments && isAppResourceReq) {
+ if (!allowsMixedArguments && !isSpecialCommand) {
checkArgument(appResource != null, "Missing application resource.");
}
@@ -208,15 +231,16 @@ List<String> buildSparkSubmitArgs() {
args.add(join(",", pyFiles));
}
- if (isAppResourceReq) {
- checkArgument(!isExample || mainClass != null, "Missing example class name.");
+ if (isExample && !isSpecialCommand) {
+ checkArgument(mainClass != null, "Missing example class name.");
}
+
if (mainClass != null) {
args.add(parser.CLASS);
args.add(mainClass);
}
- args.addAll(sparkArgs);
+ args.addAll(parsedArgs);
if (appResource != null) {
args.add(appResource);
}
@@ -399,7 +423,12 @@ private List findExamplesJars() {
private class OptionParser extends SparkSubmitOptionParser {
- boolean isAppResourceReq = true;
+ boolean isSpecialCommand = false;
+ private final boolean errorOnUnknownArgs;
+
+ OptionParser(boolean errorOnUnknownArgs) {
+ this.errorOnUnknownArgs = errorOnUnknownArgs;
+ }
@Override
protected boolean handle(String opt, String value) {
@@ -443,23 +472,20 @@ protected boolean handle(String opt, String value) {
break;
case KILL_SUBMISSION:
case STATUS:
- isAppResourceReq = false;
- sparkArgs.add(opt);
- sparkArgs.add(value);
+ isSpecialCommand = true;
+ parsedArgs.add(opt);
+ parsedArgs.add(value);
break;
case HELP:
case USAGE_ERROR:
- isAppResourceReq = false;
- sparkArgs.add(opt);
- break;
case VERSION:
- isAppResourceReq = false;
- sparkArgs.add(opt);
+ isSpecialCommand = true;
+ parsedArgs.add(opt);
break;
default:
- sparkArgs.add(opt);
+ parsedArgs.add(opt);
if (value != null) {
- sparkArgs.add(value);
+ parsedArgs.add(value);
}
break;
}
@@ -483,12 +509,13 @@ protected boolean handleUnknown(String opt) {
mainClass = className;
appResource = SparkLauncher.NO_RESOURCE;
return false;
- } else {
+ } else if (errorOnUnknownArgs) {
checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt);
checkState(appResource == null, "Found unrecognized argument but resource is already set.");
appResource = opt;
return false;
}
+ return true;
}
@Override
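For reference, the options the parser above marks as special are exactly those that can run without an application resource or main class. A standalone sketch of that rule (not the launcher's own code; the internal USAGE_ERROR marker is omitted):

    // spark-submit options that need neither an appResource nor a mainClass
    val specialOptions = Set("--kill", "--status", "--help", "--version")

    def isSpecialCommand(args: Seq[String]): Boolean =
      args.exists(specialOptions.contains)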
diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
index d16337a319be3..f8dc0ec7a0bf6 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
@@ -185,6 +185,34 @@ public void testStreamFiltering() throws Exception {
}
}
+ @Test
+ public void testAppHandleDisconnect() throws Exception {
+ LauncherServer server = LauncherServer.getOrCreateServer();
+ ChildProcAppHandle handle = new ChildProcAppHandle(server);
+ String secret = server.registerHandle(handle);
+
+ TestClient client = null;
+ try {
+ Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort());
+ client = new TestClient(s);
+ client.send(new Hello(secret, "1.4.0"));
+ client.send(new SetAppId("someId"));
+
+ // Wait until we know the server has received the messages and matched the handle to the
+ // connection before disconnecting.
+ eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> {
+ assertEquals("someId", handle.getAppId());
+ });
+
+ handle.disconnect();
+ waitForError(client, secret);
+ } finally {
+ handle.kill();
+ close(client);
+ client.clientThread.join();
+ }
+ }
+
private void close(Closeable c) {
if (c != null) {
try {
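The new test drives the public SparkAppHandle.disconnect() path. A minimal usage sketch with a hypothetical application:

    import org.apache.spark.launcher.{SparkAppHandle, SparkLauncher}

    val handle: SparkAppHandle = new SparkLauncher()
      .setAppResource("/path/to/app.jar")   // hypothetical
      .setMainClass("com.example.MyApp")    // hypothetical
      .setMaster("local[2]")
      .startApplication()

    // Stop receiving state updates for this handle without killing the application.
    handle.disconnect()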
diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
index 2e050f8413074..b343094b2e7b8 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.launcher;
import java.io.File;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
@@ -27,7 +28,10 @@
import org.junit.AfterClass;
import org.junit.BeforeClass;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
import static org.junit.Assert.*;
public class SparkSubmitCommandBuilderSuite extends BaseSuite {
@@ -35,6 +39,9 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite {
private static File dummyPropsFile;
private static SparkSubmitOptionParser parser;
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
@BeforeClass
public static void setUp() throws Exception {
dummyPropsFile = File.createTempFile("spark", "properties");
@@ -74,8 +81,11 @@ public void testCliHelpAndNoArg() throws Exception {
@Test
public void testCliKillAndStatus() throws Exception {
- testCLIOpts(parser.STATUS);
- testCLIOpts(parser.KILL_SUBMISSION);
+ List<String> params = Arrays.asList("driver-20160531171222-0000");
+ testCLIOpts(null, parser.STATUS, params);
+ testCLIOpts(null, parser.KILL_SUBMISSION, params);
+ testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.STATUS, params);
+ testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.KILL_SUBMISSION, params);
}
@Test
@@ -190,6 +200,33 @@ public void testSparkRShell() throws Exception {
env.get("SPARKR_SUBMIT_ARGS"));
}
+ @Test(expected = IllegalArgumentException.class)
+ public void testExamplesRunnerNoArg() throws Exception {
+ List<String> sparkSubmitArgs = Arrays.asList(SparkSubmitCommandBuilder.RUN_EXAMPLE);
+ Map<String, String> env = new HashMap<>();
+ buildCommand(sparkSubmitArgs, env);
+ }
+
+ @Test
+ public void testExamplesRunnerNoMainClass() throws Exception {
+ testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.HELP, null);
+ testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.USAGE_ERROR, null);
+ testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.VERSION, null);
+ }
+
+ @Test
+ public void testExamplesRunnerWithMasterNoMainClass() throws Exception {
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("Missing example class name.");
+
+ List<String> sparkSubmitArgs = Arrays.asList(
+ SparkSubmitCommandBuilder.RUN_EXAMPLE,
+ parser.MASTER + "=foo"
+ );
+ Map<String, String> env = new HashMap<>();
+ buildCommand(sparkSubmitArgs, env);
+ }
+
@Test
public void testExamplesRunner() throws Exception {
List<String> sparkSubmitArgs = Arrays.asList(
@@ -344,10 +381,17 @@ private List<String> buildCommand(List<String> args, Map<String, String> env) th
return newCommandBuilder(args).buildCommand(env);
}
- private void testCLIOpts(String opt) throws Exception {
- List<String> helpArgs = Arrays.asList(opt, "driver-20160531171222-0000");
+ private void testCLIOpts(String appResource, String opt, List<String> params) throws Exception {
+ List<String> args = new ArrayList<>();
+ if (appResource != null) {
+ args.add(appResource);
+ }
+ args.add(opt);
+ if (params != null) {
+ args.addAll(params);
+ }
Map<String, String> env = new HashMap<>();
- List<String> cmd = buildCommand(helpArgs, env);
+ List<String> cmd = buildCommand(args, env);
assertTrue(opt + " should be contained in the final cmd.",
cmd.contains(opt));
}
diff --git a/licenses/LICENSE-jmock.txt b/licenses/LICENSE-jmock.txt
new file mode 100644
index 0000000000000..ed7964fe3d9ef
--- /dev/null
+++ b/licenses/LICENSE-jmock.txt
@@ -0,0 +1,28 @@
+Copyright (c) 2000-2017, jMock.org
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer. Redistributions
+in binary form must reproduce the above copyright notice, this list of
+conditions and the following disclaimer in the documentation and/or
+other materials provided with the distribution.
+
+Neither the name of jMock nor the names of its contributors may be
+used to endorse or promote products derived from this software without
+specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
new file mode 100644
index 0000000000000..f14431d50feec
--- /dev/null
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
@@ -0,0 +1,4 @@
+org.apache.spark.ml.regression.InternalLinearRegressionModelWriter
+org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter
+org.apache.spark.ml.clustering.InternalKMeansModelWriter
+org.apache.spark.ml.clustering.PMMLKMeansModelWriter
\ No newline at end of file
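This service file is what GeneralMLWriter consults through java.util.ServiceLoader to map a format name onto a writer implementation. A sketch of the resulting call, assuming an already-fitted KMeansModel named kmeansModel:

    // "internal" is the default format; "pmml" resolves through the registry above.
    kmeansModel.write.format("pmml").save("/tmp/kmeans-pmml")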
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 08b0cb9b8f6a5..d8f3dfa874439 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -219,7 +219,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
/**
* Predict label for the given features.
- * This internal method is used to implement `transform()` and output [[predictionCol]].
+ * This method is used to implement `transform()` and output [[predictionCol]].
*/
- protected def predict(features: FeaturesType): Double
+ @Since("2.4.0")
+ def predict(features: FeaturesType): Double
}
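With predict now public, a single feature vector can be scored directly instead of assembling a one-row DataFrame. A sketch, assuming model is any fitted PredictionModel (for example a LinearRegressionModel):

    import org.apache.spark.ml.linalg.Vectors

    val prediction: Double = model.predict(Vectors.dense(0.5, 1.2, -0.3))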
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 9d1d5aa1e0cff..7e5790ab70ee9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkException
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -192,12 +192,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
/**
* Predict label for the given features.
- * This internal method is used to implement `transform()` and output [[predictionCol]].
+ * This method is used to implement `transform()` and output [[predictionCol]].
*
* This default implementation for classification predicts the index of the maximum value
* from `predictRaw()`.
*/
- override protected def predict(features: FeaturesType): Double = {
+ override def predict(features: FeaturesType): Double = {
raw2prediction(predictRaw(features))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 9f60f0896ec52..c9786f1f7ceb1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -97,9 +97,11 @@ class DecisionTreeClassifier @Since("1.4.0") (
override def setSeed(value: Long): this.type = set(seed, value)
override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
+ val instr = Instrumentation.create(this, dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
+ instr.logNumClasses(numClasses)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -110,8 +112,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
- val instr = Instrumentation.create(this, oldDataset)
- instr.logParams(params: _*)
+ instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+ cacheNodeIds, checkpointInterval, impurity, seed)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
@@ -125,7 +127,8 @@ class DecisionTreeClassifier @Since("1.4.0") (
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
val instr = Instrumentation.create(this, data)
- instr.logParams(params: _*)
+ instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+ cacheNodeIds, checkpointInterval, impurity, seed)
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, instr = Some(instr), parentUID = Some(uid))
@@ -165,7 +168,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi
@Since("1.4.0")
class DecisionTreeClassificationModel private[ml] (
@Since("1.4.0")override val uid: String,
- @Since("1.4.0")override val rootNode: Node,
+ @Since("1.4.0")override val rootNode: ClassificationNode,
@Since("1.6.0")override val numFeatures: Int,
@Since("1.5.0")override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
@@ -178,10 +181,10 @@ class DecisionTreeClassificationModel private[ml] (
* Construct a decision tree classification model.
* @param rootNode Root node of tree, with other nodes attached.
*/
- private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
+ private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) =
this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
}
@@ -276,9 +279,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
- val root = loadTreeNodes(path, metadata, sparkSession)
- val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true)
+ val model = new DecisionTreeClassificationModel(metadata.uid,
+ root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
+ metadata.getAndSetParams(model)
model
}
}
@@ -292,9 +296,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
require(oldModel.algo == OldAlgo.Classification,
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
- val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
// Can't infer number of features from old model, so default to -1
- new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
+ new DecisionTreeClassificationModel(uid,
+ rootNode.asInstanceOf[ClassificationNode], numFeatures, -1)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index f11bc1d8fe415..337133a2e2326 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -146,12 +146,21 @@ class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setValidationIndicatorCol(value: String): this.type = {
+ set(validationIndicatorCol, value)
+ }
+
override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+
+ val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
+
// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
// 2 classes now. This lets us provide a more precise error message.
- val oldDataset: RDD[LabeledPoint] =
+ val convert2LabeledPoint = (dataset: Dataset[_]) => {
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label == 0 || label == 1, s"GBTClassifier was given" +
@@ -159,7 +168,18 @@ class GBTClassifier @Since("1.4.0") (
s" GBTClassifier currently only supports binary classification.")
LabeledPoint(label, features)
}
- val numFeatures = oldDataset.first().features.size
+ }
+
+ val (trainDataset, validationDataset) = if (withValidation) {
+ (
+ convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))),
+ convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol))))
+ )
+ } else {
+ (convert2LabeledPoint(dataset), null)
+ }
+
+ val numFeatures = trainDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val numClasses = 2
@@ -169,15 +189,21 @@ class GBTClassifier @Since("1.4.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
- seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
+ seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
+ validationIndicatorCol)
instr.logNumFeatures(numFeatures)
instr.logNumClasses(numClasses)
- val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
- $(seed), $(featureSubsetStrategy))
+ val (baseLearners, learnerWeights) = if (withValidation) {
+ GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy))
+ } else {
+ GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy))
+ }
+
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
instr.logSuccess(m)
m
@@ -267,7 +293,7 @@ class GBTClassificationModel private[ml](
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
// If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
if (isDefined(thresholds)) {
super.predict(features)
@@ -334,6 +360,21 @@ class GBTClassificationModel private[ml](
// hard coded loss, which is not meant to be changed in the model
private val loss = getOldLossType
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param dataset Dataset for validation.
+ */
+ @Since("2.4.0")
+ def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
+ val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+ }
+ GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
+ OldAlgo.Classification
+ )
+ }
+
@Since("2.0.0")
override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
}
@@ -371,22 +412,22 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
override def load(path: String): GBTClassificationModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
- EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
val trees: Array[DecisionTreeRegressionModel] = treesData.map {
case (treeMetadata, root) =>
- val tree =
- new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
- DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
+ root.asInstanceOf[RegressionNode], numFeatures)
+ treeMetadata.getAndSetParams(tree)
tree
}
require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
s" trees based on metadata but found ${trees.length} trees.")
val model = new GBTClassificationModel(metadata.uid,
trees, treeWeights, numFeatures)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
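Taken together, the two GBTClassifier additions look like this from the caller's side. A sketch with hypothetical DataFrames and column names:

    import org.apache.spark.ml.classification.GBTClassifier

    val gbt = new GBTClassifier()
      .setMaxIter(50)
      .setValidationIndicatorCol("isValidation")   // boolean column: true marks validation rows

    val model = gbt.fit(labeledData)               // labeledData: label, features, isValidation
    // Per-iteration error on a held-out set via the new evaluateEachIteration.
    val errors: Array[Double] = model.evaluateEachIteration(heldOutData)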
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index ce400f4f1faf7..38eb04556b775 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") (
Instance(label, weight, features)
}
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold,
aggregationDepth)
@@ -187,6 +187,9 @@ class LinearSVC @Since("2.2.0") (
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
)(seqOp, combOp, $(aggregationDepth))
}
+ instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count)
+ instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
+ instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
val histogram = labelSummarizer.histogram
val numInvalid = labelSummarizer.countInvalid
@@ -209,7 +212,7 @@ class LinearSVC @Since("2.2.0") (
if (numInvalid != 0) {
val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
s"Found $numInvalid invalid labels."
- logError(msg)
+ instr.logError(msg)
throw new SparkException(msg)
}
@@ -246,7 +249,7 @@ class LinearSVC @Since("2.2.0") (
bcFeaturesStd.destroy(blocking = false)
if (state == null) {
val msg = s"${optimizer.getClass.getName} failed."
- logError(msg)
+ instr.logError(msg)
throw new SparkException(msg)
}
@@ -316,7 +319,7 @@ class LinearSVCModel private[classification] (
BLAS.dot(features, coefficients) + intercept
}
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
if (margin(features) > $(threshold)) 1.0 else 0.0
}
@@ -377,7 +380,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
val Row(coefficients: Vector, intercept: Double) =
data.select("coefficients", "intercept").head()
val model = new LinearSVCModel(metadata.uid, coefficients, intercept)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index fa191604218db..06ca37bc75146 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -500,7 +500,7 @@ class LogisticRegression @Since("1.2.0") (
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(regParam, elasticNetParam, standardization, threshold,
maxIter, tol, fitIntercept)
@@ -517,6 +517,9 @@ class LogisticRegression @Since("1.2.0") (
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
)(seqOp, combOp, $(aggregationDepth))
}
+ instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count)
+ instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
+ instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
val histogram = labelSummarizer.histogram
val numInvalid = labelSummarizer.countInvalid
@@ -560,15 +563,15 @@ class LogisticRegression @Since("1.2.0") (
if (numInvalid != 0) {
val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
s"Found $numInvalid invalid labels."
- logError(msg)
+ instr.logError(msg)
throw new SparkException(msg)
}
val isConstantLabel = histogram.count(_ != 0.0) == 1
if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) {
- logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " +
- s"will be zeros. Training is not needed.")
+ instr.logWarning(s"All labels are the same value and fitIntercept=true, so the " +
+ s"coefficients will be zeros. Training is not needed.")
val constantLabelIndex = Vectors.dense(histogram).argmax
val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures,
new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double],
@@ -581,7 +584,7 @@ class LogisticRegression @Since("1.2.0") (
(coefMatrix, interceptVec, Array.empty[Double])
} else {
if (!$(fitIntercept) && isConstantLabel) {
- logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
+ instr.logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
s"dangerous ground, so the algorithm may not converge.")
}
@@ -590,7 +593,7 @@ class LogisticRegression @Since("1.2.0") (
if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
- logWarning("Fitting LogisticRegressionModel without intercept on dataset with " +
+ instr.logWarning("Fitting LogisticRegressionModel without intercept on dataset with " +
"constant nonzero column, Spark MLlib outputs zero coefficients for constant " +
"nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.")
}
@@ -708,7 +711,7 @@ class LogisticRegression @Since("1.2.0") (
(_initialModel.interceptVector.size == numCoefficientSets) &&
(_initialModel.getFitIntercept == $(fitIntercept))
if (!modelIsValid) {
- logWarning(s"Initial coefficients will be ignored! Its dimensions " +
+ instr.logWarning(s"Initial coefficients will be ignored! Its dimensions " +
s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " +
s"expected size ($numCoefficientSets, $numFeatures)")
}
@@ -813,7 +816,7 @@ class LogisticRegression @Since("1.2.0") (
if (state == null) {
val msg = s"${optimizer.getClass.getName} failed."
- logError(msg)
+ instr.logError(msg)
throw new SparkException(msg)
}
@@ -1090,7 +1093,7 @@ class LogisticRegressionModel private[spark] (
* Predict label for the given feature vector.
* The behavior of this can be adjusted using `thresholds`.
*/
- override protected def predict(features: Vector): Double = if (isMultinomial) {
+ override def predict(features: Vector): Double = if (isMultinomial) {
super.predict(features)
} else {
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
@@ -1267,7 +1270,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
numClasses, isMultinomial)
}
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index fd4c98f22132f..57ba47e596a97 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -322,7 +322,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
* Predict label for the given features.
* This internal method is used to implement `transform()` and output [[predictionCol]].
*/
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
LabelConverter.decodeLabel(mlpModel.predict(features))
}
@@ -388,7 +388,7 @@ object MultilayerPerceptronClassificationModel
val weights = data.getAs[Vector](1)
val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 0293e03d47435..1dde18d2d1a31 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -126,8 +126,10 @@ class NaiveBayes @Since("1.5.0") (
private[spark] def trainWithLabelCheck(
dataset: Dataset[_],
positiveLabel: Boolean): NaiveBayesModel = {
+ val instr = Instrumentation.create(this, dataset)
if (positiveLabel && isDefined(thresholds)) {
val numClasses = getNumClasses(dataset)
+ instr.logNumClasses(numClasses)
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
@@ -146,7 +148,6 @@ class NaiveBayes @Since("1.5.0") (
}
}
- val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
probabilityCol, modelType, smoothing, thresholds)
@@ -407,7 +408,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
.head()
val model = new NaiveBayesModel(metadata.uid, pi, theta)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index f04fde2cbbca1..3474b61e40136 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
import org.apache.spark.ml.util._
@@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait {
/**
* Params for [[OneVsRest]].
*/
-private[ml] trait OneVsRestParams extends PredictorParams
+private[ml] trait OneVsRestParams extends ClassifierParams
with ClassifierTypeTrait with HasWeightCol {
/**
@@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] (
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
+ require(models.nonEmpty, "OneVsRestModel requires at least one model for one class")
+
+ @Since("2.4.0")
+ val numClasses: Int = models.length
+
+ @Since("2.4.0")
+ val numFeatures: Int = models.head.numFeatures
+
/** @group setParam */
@Since("2.1.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] (
@Since("2.1.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
+
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
@@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] (
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
predictions + ((index, prediction(1)))
}
+
model.setFeaturesCol($(featuresCol))
val transformedDataset = model.transform(df).select(columns: _*)
val updatedDataset = transformedDataset
@@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] (
newDataset.unpersist()
}
- // output the index of the classifier with highest confidence as prediction
- val labelUDF = udf { (predictions: Map[Int, Double]) =>
- predictions.maxBy(_._2)._1.toDouble
- }
+ if (getRawPredictionCol != "") {
+ val numClass = models.length
- // output label and label metadata as prediction
- aggregatedDataset
- .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
- .drop(accColName)
+ // output the RawPrediction as vector
+ val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
+ val predArray = Array.fill[Double](numClass)(0.0)
+ predictions.foreach { case (idx, value) => predArray(idx) = value }
+ Vectors.dense(predArray)
+ }
+
+ // output the index of the classifier with highest confidence as prediction
+ val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble }
+
+ // output confidence as raw prediction, label and label metadata as prediction
+ aggregatedDataset
+ .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
+ .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
+ .drop(accColName)
+ } else {
+ // output the index of the classifier with highest confidence as prediction
+ val labelUDF = udf { (predictions: Map[Int, Double]) =>
+ predictions.maxBy(_._2)._1.toDouble
+ }
+ // output label and label metadata as prediction
+ aggregatedDataset
+ .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata)
+ .drop(accColName)
+ }
}
@Since("1.4.1")
@@ -257,7 +289,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
}
val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
- DefaultParamsReader.getAndSetParams(ovrModel, metadata)
+ metadata.getAndSetParams(ovrModel)
ovrModel.set("classifier", classifier)
ovrModel
}
@@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
+
/**
* The implementation of parallel one vs. rest runs the classification for
* each class in a separate thread.
@@ -330,7 +366,7 @@ final class OneVsRest @Since("1.4.0") (
transformSchema(dataset.schema)
val instr = Instrumentation.create(this, dataset)
- instr.logParams(labelCol, featuresCol, predictionCol, parallelism)
+ instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol)
instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
// determine number of classes either from metadata if provided, or via computation.
@@ -347,7 +383,7 @@ final class OneVsRest @Since("1.4.0") (
getClassifier match {
case _: HasWeightCol => true
case c =>
- logWarning(s"weightCol is ignored, as it is not supported by $c now.")
+ instr.logWarning(s"weightCol is ignored, as it is not supported by $c now.")
false
}
}
@@ -448,7 +484,7 @@ object OneVsRest extends MLReadable[OneVsRest] {
override def load(path: String): OneVsRest = {
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
val ovr = new OneVsRest(metadata.uid)
- DefaultParamsReader.getAndSetParams(ovr, metadata)
+ metadata.getAndSetParams(ovr)
ovr.setClassifier(classifier)
}
}
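From the caller's side, the OneVsRest changes add a raw prediction column holding the per-class confidences as a vector, plus numClasses and numFeatures on the model. A sketch with hypothetical DataFrames:

    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

    val ovr = new OneVsRest()
      .setClassifier(new LogisticRegression())
      .setRawPredictionCol("rawPrediction")

    val ovrModel = ovr.fit(trainDF)
    val scored = ovrModel.transform(testDF)   // gains "rawPrediction" and "prediction" columns
    println(s"classes=${ovrModel.numClasses}, features=${ovrModel.numFeatures}")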
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 78a4972adbdbb..040db3b94b041 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -116,6 +116,7 @@ class RandomForestClassifier @Since("1.4.0") (
set(featureSubsetStrategy, value)
override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
+ val instr = Instrumentation.create(this, dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
@@ -130,7 +131,6 @@ class RandomForestClassifier @Since("1.4.0") (
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
- val instr = Instrumentation.create(this, oldDataset)
instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
@@ -141,6 +141,8 @@ class RandomForestClassifier @Since("1.4.0") (
val numFeatures = oldDataset.first().features.size
val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
instr.logSuccess(m)
m
}
@@ -310,23 +312,23 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
override def load(path: String): RandomForestClassificationModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
- EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
val trees: Array[DecisionTreeClassificationModel] = treesData.map {
case (treeMetadata, root) =>
- val tree =
- new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
- DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ val tree = new DecisionTreeClassificationModel(treeMetadata.uid,
+ root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
+ treeMetadata.getAndSetParams(tree)
tree
}
require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
s" trees based on metadata but found ${trees.length} trees.")
val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 4c20e6563bad1..9c9614509c64f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -22,24 +22,23 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
-import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
+import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
+ BisectingKMeansModel => MLlibBisectingKMeansModel}
import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
/**
* Common params for BisectingKMeans and BisectingKMeansModel
*/
-private[clustering] trait BisectingKMeansParams extends Params
- with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
+private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter
+ with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure {
/**
* The desired number of leaf clusters. Must be > 1. Default: 4.
@@ -74,7 +73,7 @@ private[clustering] trait BisectingKMeansParams extends Params
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
}
@@ -104,11 +103,16 @@ class BisectingKMeansModel private[ml] (
@Since("2.1.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ /** @group expertSetParam */
+ @Since("2.4.0")
+ def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val predictUDF = udf((vector: Vector) => predict(vector))
- dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ dataset.withColumn($(predictionCol),
+ predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
}
@Since("2.0.0")
@@ -127,9 +131,9 @@ class BisectingKMeansModel private[ml] (
*/
@Since("2.0.0")
def computeCost(dataset: Dataset[_]): Double = {
- SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
- val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
- parentModel.computeCost(data.map(OldVectors.fromML))
+ SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol)
+ val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
+ parentModel.computeCost(data)
}
@Since("2.0.0")
@@ -188,7 +192,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
val dataPath = new Path(path, "data").toString
val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
val model = new BisectingKMeansModel(metadata.uid, mllibModel)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
@@ -248,26 +252,31 @@ class BisectingKMeans @Since("2.0.0") (
@Since("2.0.0")
def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
+ /** @group expertSetParam */
+ @Since("2.4.0")
+ def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
@Since("2.0.0")
override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
transformSchema(dataset.schema, logging = true)
- val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
- case Row(point: Vector) => OldVectors.fromML(point)
- }
+ val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
- val instr = Instrumentation.create(this, rdd)
- instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize)
+ val instr = Instrumentation.create(this, dataset)
+ instr.logParams(featuresCol, predictionCol, k, maxIter, seed,
+ minDivisibleClusterSize, distanceMeasure)
val bkm = new MLlibBisectingKMeans()
.setK($(k))
.setMaxIterations($(maxIter))
.setMinDivisibleClusterSize($(minDivisibleClusterSize))
.setSeed($(seed))
+ .setDistanceMeasure($(distanceMeasure))
val parentModel = bkm.run(rdd)
val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
val summary = new BisectingKMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
+ instr.logNamedValue("clusterSizes", summary.clusterSizes)
instr.logSuccess(model)
model
}
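The distanceMeasure expert param added here mirrors the existing KMeans option; besides the default "euclidean", "cosine" is supported. A sketch with a hypothetical features DataFrame:

    import org.apache.spark.ml.clustering.BisectingKMeans

    val bkm = new BisectingKMeans()
      .setK(4)
      .setDistanceMeasure("cosine")

    val bkmModel = bkm.fit(featuresDF)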
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index f19ad7a5a6938..64ecc1ebda589 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatr
Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -63,7 +63,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT)
}
@@ -109,8 +109,9 @@ class GaussianMixtureModel private[ml] (
transformSchema(dataset.schema, logging = true)
val predUDF = udf((vector: Vector) => predict(vector))
val probUDF = udf((vector: Vector) => predictProbability(vector))
- dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
- .withColumn($(probabilityCol), probUDF(col($(featuresCol))))
+ dataset
+ .withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
+ .withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
}
@Since("2.0.0")
@@ -233,7 +234,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
}
val model = new GaussianMixtureModel(metadata.uid, weights, gaussians)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
@@ -340,7 +341,8 @@ class GaussianMixture @Since("2.0.0") (
val sc = dataset.sparkSession.sparkContext
val numClusters = $(k)
- val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map {
+ val instances: RDD[Vector] = dataset
+ .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map {
case Row(features: Vector) => features
}.cache()
@@ -350,7 +352,7 @@ class GaussianMixture @Since("2.0.0") (
s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
s" matrix is quadratic in the number of features.")
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol)
instr.logNumFeatures(numFeatures)
@@ -423,6 +425,8 @@ class GaussianMixture @Since("2.0.0") (
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood)
model.setSummary(Some(summary))
+ instr.logNamedValue("logLikelihood", logLikelihood)
+ instr.logNamedValue("clusterSizes", summary.clusterSizes)
instr.logSuccess(model)
model
}
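Switching to validateVectorCompatibleColumn and columnToVector means the features column no longer has to be a VectorUDT; vector-compatible columns such as array<double> are accepted as well. A sketch, assuming an active SparkSession named spark:

    import org.apache.spark.ml.clustering.GaussianMixture

    val df = spark.createDataFrame(Seq(
      Tuple1(Array(0.0, 1.0)),
      Tuple1(Array(5.0, 6.0))
    )).toDF("features")

    val gmmModel = new GaussianMixture().setK(2).fit(df)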
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index c8145de564cbe..1704412741d49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -17,12 +17,14 @@
package org.apache.spark.ml.clustering
+import scala.collection.mutable
+
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
@@ -30,8 +32,8 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion
@@ -40,7 +42,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
* Common params for KMeans and KMeansModel
*/
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
- with HasSeed with HasPredictionCol with HasTol {
+ with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure {
/**
* The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -71,15 +73,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Since("1.5.0")
def getInitMode: String = $(initMode)
- @Since("2.4.0")
- final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " +
- "Supported options: 'euclidean' and 'cosine'.",
- (value: String) => MLlibKMeans.validateDistanceMeasure(value))
-
- /** @group expertGetParam */
- @Since("2.4.0")
- def getDistanceMeasure: String = $(distanceMeasure)
-
/**
* Param for the number of steps for the k-means|| initialization mode. This is an advanced
* setting -- the default of 2 is almost always enough. Must be > 0. Default: 2.
@@ -99,7 +92,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
}
@@ -112,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Since("1.5.0")
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
- private val parentModel: MLlibKMeansModel)
- extends Model[KMeansModel] with KMeansParams with MLWritable {
+ private[clustering] val parentModel: MLlibKMeansModel)
+ extends Model[KMeansModel] with KMeansParams with GeneralMLWritable {
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
@@ -132,8 +125,11 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
+
val predictUDF = udf((vector: Vector) => predict(vector))
- dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+
+ dataset.withColumn($(predictionCol),
+ predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
}
@Since("1.5.0")
@@ -153,22 +149,20 @@ class KMeansModel private[ml] (
// TODO: Replace the temp fix when we have proper evaluators defined for clustering.
@Since("2.0.0")
def computeCost(dataset: Dataset[_]): Double = {
- SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
- val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
- case Row(point: Vector) => OldVectors.fromML(point)
- }
+ SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol)
+ val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
parentModel.computeCost(data)
}
/**
- * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
+ * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
*
* For [[KMeansModel]], this does NOT currently save the training [[summary]].
* An option to save [[summary]] may be added in the future.
*
*/
@Since("1.6.0")
- override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+ override def write: GeneralMLWriter = new GeneralMLWriter(this)
private var trainingSummary: Option[KMeansSummary] = None
@@ -194,6 +188,47 @@ class KMeansModel private[ml] (
}
}
+/** Helper class for storing model data */
+private case class ClusterData(clusterIdx: Int, clusterCenter: Vector)
+
+
+/** A writer for KMeans that handles the "internal" (or default) format */
+private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
+
+ override def format(): String = "internal"
+ override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
+
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ val instance = stage.asInstanceOf[KMeansModel]
+ val sc = sparkSession.sparkContext
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: cluster centers
+ val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
+ case (center, idx) =>
+ ClusterData(idx, center)
+ }
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
+ }
+}
+
+/** A writer for KMeans that handles the "pmml" format */
+private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister {
+
+ override def format(): String = "pmml"
+ override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel"
+
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ val instance = stage.asInstanceOf[KMeansModel]
+ val sc = sparkSession.sparkContext
+ instance.parentModel.toPMML(sc, path)
+ }
+}
+
+
@Since("1.6.0")
object KMeansModel extends MLReadable[KMeansModel] {
@@ -203,30 +238,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
@Since("1.6.0")
override def load(path: String): KMeansModel = super.load(path)
- /** Helper class for storing model data */
- private case class Data(clusterIdx: Int, clusterCenter: Vector)
-
/**
* We store all cluster centers in a single row and use this class to store model data by
* Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility.
*/
private case class OldData(clusterCenters: Array[OldVector])
- /** [[MLWriter]] instance for [[KMeansModel]] */
- private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
-
- override protected def saveImpl(path: String): Unit = {
- // Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
- // Save model data: cluster centers
- val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
- Data(idx, center)
- }
- val dataPath = new Path(path, "data").toString
- sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
- }
- }
-
private class KMeansModelReader extends MLReader[KMeansModel] {
/** Checked against metadata when loading model */
@@ -241,14 +258,14 @@ object KMeansModel extends MLReadable[KMeansModel] {
val dataPath = new Path(path, "data").toString
val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
- val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
+ val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData]
data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
} else {
// Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters
}
val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
@@ -319,15 +336,13 @@ class KMeans @Since("1.5.0") (
transformSchema(dataset.schema, logging = true)
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
- val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
- case Row(point: Vector) => OldVectors.fromML(point)
- }
+ val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
if (handlePersistence) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
}
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
maxIter, seed, tol)
val algo = new MLlibKMeans()
@@ -344,6 +359,7 @@ class KMeans @Since("1.5.0") (
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
+ instr.logNamedValue("clusterSizes", summary.clusterSizes)
instr.logSuccess(model)
if (handlePersistence) {
instances.unpersist()
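As a rough usage sketch of the writer plumbing added above (the training DataFrame, paths and variable names are illustrative assumptions, not part of the patch), the registered "internal" and "pmml" formats can be selected through the model's GeneralMLWriter:

    import org.apache.spark.ml.clustering.KMeans

    // `training` is assumed to be a DataFrame with a vector-typed "features" column.
    val model = new KMeans().setK(3).setSeed(1L).fit(training)

    // Default Parquet-based layout, handled by InternalKMeansModelWriter.
    model.write.overwrite().save("/tmp/kmeans-internal")

    // PMML export, dispatched to PMMLKMeansModelWriter through MLFormatRegister.
    model.write.format("pmml").save("/tmp/kmeans-pmml")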
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 4bab670cc159f..fed42c959b5ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -43,7 +43,7 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructType}
import org.apache.spark.util.PeriodicCheckpointer
import org.apache.spark.util.VersionUtils
@@ -345,7 +345,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
s" must be >= 1. Found value: $getTopicConcentration")
}
}
- SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol)
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}
@@ -366,7 +366,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
private object LDAParams {
/**
- * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
+ * Equivalent to [[Metadata.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
* formats saved with Spark 1.6, which differ from the formats in Spark 2.0+.
*
* @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with
@@ -391,7 +391,7 @@ private object LDAParams {
s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
}
case _ => // 2.0+
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
}
}
}
@@ -461,7 +461,8 @@ abstract class LDAModel private[ml] (
val transformer = oldLocalModel.getTopicDistributionMethod
val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML }
- dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF()
+ dataset.withColumn($(topicDistributionCol),
+ t(DatasetUtils.columnToVector(dataset, getFeaturesCol))).toDF()
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
" such as topicDistributionCol to produce results.")
@@ -568,10 +569,14 @@ abstract class LDAModel private[ml] (
class LocalLDAModel private[ml] (
uid: String,
vocabSize: Int,
- @Since("1.6.0") override private[clustering] val oldLocalModel: OldLocalLDAModel,
+ private[clustering] val oldLocalModel_ : OldLocalLDAModel,
sparkSession: SparkSession)
extends LDAModel(uid, vocabSize, sparkSession) {
+ override private[clustering] def oldLocalModel: OldLocalLDAModel = {
+ oldLocalModel_.setSeed(getSeed)
+ }
+
@Since("1.6.0")
override def copy(extra: ParamMap): LocalLDAModel = {
val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession)
@@ -938,7 +943,7 @@ object LDA extends MLReadable[LDA] {
featuresCol: String): RDD[(Long, OldVector)] = {
dataset
.withColumn("docId", monotonically_increasing_id())
- .select("docId", featuresCol)
+ .select(col("docId"), DatasetUtils.columnToVector(dataset, featuresCol))
.rdd
.map { case Row(docId: Long, features: Vector) =>
(docId, OldVectors.fromML(features))
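The move from checkColumnType(..., new VectorUDT) to validateVectorCompatibleColumn, together with DatasetUtils.columnToVector and the new ArrayType/DoubleType/FloatType imports, suggests the features column no longer has to be a VectorUDT column. A hedged sketch of what that would allow (assuming a SparkSession named spark; data and column names are illustrative):

    import spark.implicits._
    import org.apache.spark.ml.clustering.KMeans

    // Features stored as array<double> rather than as ML vectors.
    val df = Seq(
      (0, Array(0.0, 0.1)),
      (1, Array(0.1, 0.0)),
      (2, Array(9.0, 9.1)),
      (3, Array(9.1, 9.0))
    ).toDF("id", "features")

    // columnToVector converts the array column internally before training.
    val model = new KMeans().setK(2).setSeed(1L).fit(df)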
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
new file mode 100644
index 0000000000000..1b9a3499947d9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types._
+
+/**
+ * Common params for PowerIterationClustering
+ */
+private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter
+ with HasWeightCol {
+
+ /**
+ * The number of clusters to create (k). Must be > 1. Default: 2.
+ * @group param
+ */
+ @Since("2.4.0")
+ final val k = new IntParam(this, "k", "The number of clusters to create. " +
+ "Must be > 1.", ParamValidators.gt(1))
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getK: Int = $(k)
+
+ /**
+ * Param for the initialization algorithm. This can be either "random" to use a random vector
+ * as vertex properties, or "degree" to use a normalized sum of similarities with other vertices.
+ * Default: random.
+ * @group expertParam
+ */
+ @Since("2.4.0")
+ final val initMode = {
+ val allowedParams = ParamValidators.inArray(Array("random", "degree"))
+ new Param[String](this, "initMode", "The initialization algorithm. This can be either " +
+ "'random' to use a random vector as vertex properties, or 'degree' to use a normalized sum " +
+ "of similarities with other vertices. Supported options: 'random' and 'degree'.",
+ allowedParams)
+ }
+
+ /** @group expertGetParam */
+ @Since("2.4.0")
+ def getInitMode: String = $(initMode)
+
+ /**
+ * Param for the name of the input column for source vertex IDs.
+ * Default: "src"
+ * @group param
+ */
+ @Since("2.4.0")
+ val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.",
+ (value: String) => value.nonEmpty)
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getSrcCol: String = getOrDefault(srcCol)
+
+ /**
+ * Name of the input column for destination vertex IDs.
+ * Default: "dst"
+ * @group param
+ */
+ @Since("2.4.0")
+ val dstCol = new Param[String](this, "dstCol",
+ "Name of the input column for destination vertex IDs.",
+ (value: String) => value.nonEmpty)
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getDstCol: String = $(dstCol)
+
+ setDefault(srcCol -> "src", dstCol -> "dst")
+}
+
+/**
+ * :: Experimental ::
+ * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by
+ * Lin and Cohen. From the abstract:
+ * PIC finds a very low-dimensional embedding of a dataset using truncated power
+ * iteration on a normalized pair-wise similarity matrix of the data.
+ *
+ * This class is not yet an Estimator/Transformer, use `assignClusters` method to run the
+ * PowerIterationClustering algorithm.
+ *
+ * @see
+ * Spectral clustering (Wikipedia)
+ */
+@Since("2.4.0")
+@Experimental
+class PowerIterationClustering private[clustering] (
+ @Since("2.4.0") override val uid: String)
+ extends PowerIterationClusteringParams with DefaultParamsWritable {
+
+ setDefault(
+ k -> 2,
+ maxIter -> 20,
+ initMode -> "random")
+
+ @Since("2.4.0")
+ def this() = this(Identifiable.randomUID("PowerIterationClustering"))
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group expertSetParam */
+ @Since("2.4.0")
+ def setInitMode(value: String): this.type = set(initMode, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setSrcCol(value: String): this.type = set(srcCol, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setDstCol(value: String): this.type = set(dstCol, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ /**
+ * Run the PIC algorithm and returns a cluster assignment for each input vertex.
+ *
+ * @param dataset A dataset with columns src, dst, weight representing the affinity matrix,
+ * which is the matrix A in the PIC paper. Suppose the src column value is i,
+ * the dst column value is j, the weight column value is similarity s,,ij,,
+ * which must be nonnegative. This is a symmetric matrix and hence
+ * s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be
+ * either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are
+ * ignored, because we assume s,,ij,, = 0.0.
+ *
+ * @return A dataset that contains columns of vertex id and the corresponding cluster for the id.
+ * The schema of it will be:
+ * - id: Long
+ * - cluster: Int
+ */
+ @Since("2.4.0")
+ def assignClusters(dataset: Dataset[_]): DataFrame = {
+ val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
+ lit(1.0)
+ } else {
+ col($(weightCol)).cast(DoubleType)
+ }
+
+ SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol), Seq(IntegerType, LongType))
+ SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol), Seq(IntegerType, LongType))
+ val rdd: RDD[(Long, Long, Double)] = dataset.select(
+ col($(srcCol)).cast(LongType),
+ col($(dstCol)).cast(LongType),
+ w).rdd.map {
+ case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight)
+ }
+ val algorithm = new MLlibPowerIterationClustering()
+ .setK($(k))
+ .setInitializationMode($(initMode))
+ .setMaxIterations($(maxIter))
+ val model = algorithm.run(rdd)
+
+ import dataset.sparkSession.implicits._
+ model.assignments.toDF
+ }
+
+ @Since("2.4.0")
+ override def copy(extra: ParamMap): PowerIterationClustering = defaultCopy(extra)
+}
+
+@Since("2.4.0")
+object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClustering] {
+
+ @Since("2.4.0")
+ override def load(path: String): PowerIterationClustering = super.load(path)
+}
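Since this first cut of PowerIterationClustering is not an Estimator/Transformer, it is driven through assignClusters directly. A hedged usage sketch (assuming a SparkSession named spark; the affinity data is illustrative):

    import spark.implicits._
    import org.apache.spark.ml.clustering.PowerIterationClustering

    // Pairwise similarities (src, dst, weight); one direction per pair is enough.
    val affinities = Seq(
      (0L, 1L, 1.0), (1L, 2L, 1.0), (0L, 2L, 1.0),
      (3L, 4L, 1.0), (4L, 5L, 1.0), (2L, 3L, 0.1)
    ).toDF("src", "dst", "weight")

    val pic = new PowerIterationClustering()
      .setK(2)
      .setMaxIter(10)
      .setWeightCol("weight")

    // Result has columns id: Long and cluster: Int.
    pic.assignClusters(affinities).show()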
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
index d6ec5223237bb..4353c46781e9d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -20,11 +20,13 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable,
+ SchemaUtils}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions.{avg, col, udf}
import org.apache.spark.sql.types.DoubleType
@@ -32,15 +34,11 @@ import org.apache.spark.sql.types.DoubleType
* :: Experimental ::
*
* Evaluator for clustering results.
- * The metric computes the Silhouette measure
- * using the squared Euclidean distance.
- *
- * The Silhouette is a measure for the validation
- * of the consistency within clusters. It ranges
- * between 1 and -1, where a value close to 1
- * means that the points in a cluster are close
- * to the other points in the same cluster and
- * far from the points of the other clusters.
+ * The metric computes the Silhouette measure using the specified distance measure.
+ *
+ * The Silhouette is a measure for the validation of the consistency within clusters. It ranges
+ * between 1 and -1, where a value close to 1 means that the points in a cluster are close to the
+ * other points in the same cluster and far from the points of the other clusters.
*/
@Experimental
@Since("2.3.0")
@@ -84,18 +82,40 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
@Since("2.3.0")
def setMetricName(value: String): this.type = set(metricName, value)
- setDefault(metricName -> "silhouette")
+ /**
+ * param for distance measure to be used in evaluation
+ * (supports `"squaredEuclidean"` (default), `"cosine"`)
+ * @group param
+ */
+ @Since("2.4.0")
+ val distanceMeasure: Param[String] = {
+ val availableValues = Array("squaredEuclidean", "cosine")
+ val allowedParams = ParamValidators.inArray(availableValues)
+ new Param(this, "distanceMeasure", "distance measure in evaluation. Supported options: " +
+ availableValues.mkString("'", "', '", "'"), allowedParams)
+ }
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getDistanceMeasure: String = $(distanceMeasure)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
+ setDefault(metricName -> "silhouette", distanceMeasure -> "squaredEuclidean")
@Since("2.3.0")
override def evaluate(dataset: Dataset[_]): Double = {
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
SchemaUtils.checkNumericType(dataset.schema, $(predictionCol))
- $(metricName) match {
- case "silhouette" =>
+ ($(metricName), $(distanceMeasure)) match {
+ case ("silhouette", "squaredEuclidean") =>
SquaredEuclideanSilhouette.computeSilhouetteScore(
- dataset, $(predictionCol), $(featuresCol)
- )
+ dataset, $(predictionCol), $(featuresCol))
+ case ("silhouette", "cosine") =>
+ CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol))
}
}
}
@@ -111,6 +131,57 @@ object ClusteringEvaluator
}
+private[evaluation] abstract class Silhouette {
+
+ /**
+ * It computes the Silhouette coefficient for a point.
+ */
+ def pointSilhouetteCoefficient(
+ clusterIds: Set[Double],
+ pointClusterId: Double,
+ pointClusterNumOfPoints: Long,
+ averageDistanceToCluster: (Double) => Double): Double = {
+ // Here we compute the average dissimilarity of the current point to any cluster of which the
+ // point is not a member.
+ // The cluster with the lowest average dissimilarity - i.e. the nearest cluster to the current
+ // point - is said to be the "neighboring cluster".
+ val otherClusterIds = clusterIds.filter(_ != pointClusterId)
+ val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min
+
+ // adjustment for excluding the node itself from the computation of the average dissimilarity
+ val currentClusterDissimilarity = if (pointClusterNumOfPoints == 1) {
+ 0.0
+ } else {
+ averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints /
+ (pointClusterNumOfPoints - 1)
+ }
+
+ if (currentClusterDissimilarity < neighboringClusterDissimilarity) {
+ 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity)
+ } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) {
+ (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1
+ } else {
+ 0.0
+ }
+ }
+
+ /**
+ * Compute the mean Silhouette values of all samples.
+ */
+ def overallScore(df: DataFrame, scoreColumn: Column): Double = {
+ df.select(avg(scoreColumn)).collect()(0).getDouble(0)
+ }
+
+ protected def getNumberOfFeatures(dataFrame: DataFrame, columnName: String): Int = {
+ val group = AttributeGroup.fromStructField(dataFrame.schema(columnName))
+ if (group.size < 0) {
+ dataFrame.select(col(columnName)).first().getAs[Vector](0).size
+ } else {
+ group.size
+ }
+ }
+}
+
/**
* SquaredEuclideanSilhouette computes the average of the
* Silhouette over all the data of the dataset, which is
@@ -259,7 +330,7 @@ object ClusteringEvaluator
* `N` is the number of points in the dataset and `W` is the number
* of worker nodes.
*/
-private[evaluation] object SquaredEuclideanSilhouette {
+private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
private[this] var kryoRegistrationPerformed: Boolean = false
@@ -299,7 +370,7 @@ private[evaluation] object SquaredEuclideanSilhouette {
df: DataFrame,
predictionCol: String,
featuresCol: String): Map[Double, ClusterStats] = {
- val numFeatures = df.select(col(featuresCol)).first().getAs[Vector](0).size
+ val numFeatures = getNumberOfFeatures(df, featuresCol)
val clustersStatsRDD = df.select(
col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"))
.rdd
@@ -336,18 +407,19 @@ private[evaluation] object SquaredEuclideanSilhouette {
* It computes the Silhouette coefficient for a point.
*
* @param broadcastedClustersMap A map of the precomputed values for each cluster.
- * @param features The [[org.apache.spark.ml.linalg.Vector]] representing the current point.
+ * @param point The [[org.apache.spark.ml.linalg.Vector]] representing the current point.
* @param clusterId The id of the cluster the current point belongs to.
* @param squaredNorm The `$\Xi_{X}$` (which is the squared norm) precomputed for the point.
* @return The Silhouette for the point.
*/
def computeSilhouetteCoefficient(
broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]],
- features: Vector,
+ point: Vector,
clusterId: Double,
squaredNorm: Double): Double = {
- def compute(squaredNorm: Double, point: Vector, clusterStats: ClusterStats): Double = {
+ def compute(targetClusterId: Double): Double = {
+ val clusterStats = broadcastedClustersMap.value(targetClusterId)
val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum)
squaredNorm +
@@ -355,41 +427,14 @@ private[evaluation] object SquaredEuclideanSilhouette {
2 * pointDotClusterFeaturesSum / clusterStats.numOfPoints
}
- // Here we compute the average dissimilarity of the
- // current point to any cluster of which the point
- // is not a member.
- // The cluster with the lowest average dissimilarity
- // - i.e. the nearest cluster to the current point -
- // is said to be the "neighboring cluster".
- var neighboringClusterDissimilarity = Double.MaxValue
- broadcastedClustersMap.value.keySet.foreach {
- c =>
- if (c != clusterId) {
- val dissimilarity = compute(squaredNorm, features, broadcastedClustersMap.value(c))
- if(dissimilarity < neighboringClusterDissimilarity) {
- neighboringClusterDissimilarity = dissimilarity
- }
- }
- }
- val currentCluster = broadcastedClustersMap.value(clusterId)
- // adjustment for excluding the node itself from
- // the computation of the average dissimilarity
- val currentClusterDissimilarity = if (currentCluster.numOfPoints == 1) {
- 0
- } else {
- compute(squaredNorm, features, currentCluster) * currentCluster.numOfPoints /
- (currentCluster.numOfPoints - 1)
- }
-
- (currentClusterDissimilarity compare neighboringClusterDissimilarity).signum match {
- case -1 => 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity)
- case 1 => (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1
- case 0 => 0.0
- }
+ pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet,
+ clusterId,
+ broadcastedClustersMap.value(clusterId).numOfPoints,
+ compute)
}
/**
- * Compute the mean Silhouette values of all samples.
+ * Compute the Silhouette score of the dataset using squared Euclidean distance measure.
*
* @param dataset The input dataset (previously clustered) on which compute the Silhouette.
* @param predictionCol The name of the column which contains the predicted cluster id
@@ -412,7 +457,7 @@ private[evaluation] object SquaredEuclideanSilhouette {
val clustersStatsMap = SquaredEuclideanSilhouette
.computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol)
- // Silhouette is reasonable only when the number of clusters is grater then 1
+ // Silhouette is reasonable only when the number of clusters is greater than 1
assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.")
val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap)
@@ -421,13 +466,194 @@ private[evaluation] object SquaredEuclideanSilhouette {
computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double)
}
- val silhouetteScore = dfWithSquaredNorm
- .select(avg(
- computeSilhouetteCoefficientUDF(
- col(featuresCol), col(predictionCol).cast(DoubleType), col("squaredNorm"))
- ))
- .collect()(0)
- .getDouble(0)
+ val silhouetteScore = overallScore(dfWithSquaredNorm,
+ computeSilhouetteCoefficientUDF(col(featuresCol), col(predictionCol).cast(DoubleType),
+ col("squaredNorm")))
+
+ bClustersStatsMap.destroy()
+
+ silhouetteScore
+ }
+}
+
+
+/**
+ * The algorithm which is implemented in this object, instead, is an efficient and parallel
+ * implementation of the Silhouette using the cosine distance measure. The cosine distance
+ * measure is defined as `1 - s` where `s` is the cosine similarity between two points.
+ *
+ * The total distance of the point `X` to the points `$C_{i}$` belonging to the cluster `$\Gamma$`
+ * is:
+ *
+ *   $$
+ *   \sum\limits_{i=1}^N d(X, C_{i}) =
+ *   \sum\limits_{i=1}^N \Big( 1 - \frac{\sum\limits_{j=1}^D x_{j} c_{ij}}{\|X\| \|C_{i}\|} \Big)
+ *   = N - \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} \Big( \sum\limits_{i=1}^N \frac{c_{ij}}{\|C_{i}\|} \Big)
+ *   $$
+ *
+ * where `$x_{j}$` is the `j`-th dimension of the point `X` and `$c_{ij}$` is the `j`-th dimension
+ * of the `i`-th point in cluster `$\Gamma$`.
+ *
+ * Then, we can define the vector:
+ *
+ *   $$
+ *   \xi_{X} : \xi_{X i} = \frac{x_{i}}{\|X\|}, i = 1, ..., D
+ *   $$
+ *
+ * which can be precomputed for each point, and the vector
+ *
+ *   $$
+ *   \Omega_{\Gamma} : \Omega_{\Gamma i} = \sum\limits_{j=1}^N \xi_{C_{j} i}, i = 1, ..., D
+ *   $$
+ *
+ * which can be precomputed too for each cluster `$\Gamma$` by its points `$C_{i}$`.
+ *
+ * With these definitions, the numerator becomes:
+ *
+ *   $$
+ *   N - \xi_{X} \cdot \Omega_{\Gamma}
+ *   $$
+ *
+ * In the implementation, the precomputed values for the clusters are distributed among the worker
+ * nodes via broadcasted variables, because we can assume that the clusters are limited in number.
+ *
+ * The main strengths of this algorithm are the low computational complexity and the intrinsic
+ * parallelism. The precomputed information for each point and for each cluster can be computed
+ * with a computational complexity which is `O(N/W)`, where `N` is the number of points in the
+ * dataset and `W` is the number of worker nodes. After that, every point can be analyzed
+ * independently from the others.
+ *
+ * For every point we need to compute the average distance to all the clusters. Since the formula
+ * above requires `O(D)` operations, this phase has a computational complexity which is
+ * `O(C*D*N/W)` where `C` is the number of clusters (which we assume quite low), `D` is the number
+ * of dimensions, `N` is the number of points in the dataset and `W` is the number of worker
+ * nodes.
+ */
+private[evaluation] object CosineSilhouette extends Silhouette {
+
+ private[this] val normalizedFeaturesColName = "normalizedFeatures"
+
+ /**
+ * The method takes the input dataset and computes the aggregated values
+ * about a cluster which are needed by the algorithm.
+ *
+ * @param df The DataFrame which contains the input data
+ * @param predictionCol The name of the column which contains the predicted cluster id
+ * for the point.
+ * @param featuresCol The name of the column which contains the feature vector of the point.
+ * @return A [[scala.collection.immutable.Map]] which associates each cluster id to its
+ * statistics (i.e. the precomputed values `N` and `$\Omega_{\Gamma}$`).
+ */
+ def computeClusterStats(
+ df: DataFrame,
+ featuresCol: String,
+ predictionCol: String): Map[Double, (Vector, Long)] = {
+ val numFeatures = getNumberOfFeatures(df, featuresCol)
+ val clustersStatsRDD = df.select(
+ col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName))
+ .rdd
+ .map { row => (row.getDouble(0), row.getAs[Vector](1)) }
+ .aggregateByKey[(DenseVector, Long)]((Vectors.zeros(numFeatures).toDense, 0L))(
+ seqOp = {
+ case ((normalizedFeaturesSum: DenseVector, numOfPoints: Long), (normalizedFeatures)) =>
+ BLAS.axpy(1.0, normalizedFeatures, normalizedFeaturesSum)
+ (normalizedFeaturesSum, numOfPoints + 1)
+ },
+ combOp = {
+ case ((normalizedFeaturesSum1, numOfPoints1), (normalizedFeaturesSum2, numOfPoints2)) =>
+ BLAS.axpy(1.0, normalizedFeaturesSum2, normalizedFeaturesSum1)
+ (normalizedFeaturesSum1, numOfPoints1 + numOfPoints2)
+ }
+ )
+
+ clustersStatsRDD
+ .collectAsMap()
+ .toMap
+ }
+
+ /**
+ * It computes the Silhouette coefficient for a point.
+ *
+ * @param broadcastedClustersMap A map of the precomputed values for each cluster.
+ * @param normalizedFeatures The [[org.apache.spark.ml.linalg.Vector]] representing the
+ * normalized features of the current point.
+ * @param clusterId The id of the cluster the current point belongs to.
+ */
+ def computeSilhouetteCoefficient(
+ broadcastedClustersMap: Broadcast[Map[Double, (Vector, Long)]],
+ normalizedFeatures: Vector,
+ clusterId: Double): Double = {
+
+ def compute(targetClusterId: Double): Double = {
+ val (normalizedFeatureSum, numOfPoints) = broadcastedClustersMap.value(targetClusterId)
+ 1 - BLAS.dot(normalizedFeatures, normalizedFeatureSum) / numOfPoints
+ }
+
+ pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet,
+ clusterId,
+ broadcastedClustersMap.value(clusterId)._2,
+ compute)
+ }
+
+ /**
+ * Compute the Silhouette score of the dataset using the cosine distance measure.
+ *
+ * @param dataset The input dataset (previously clustered) on which compute the Silhouette.
+ * @param predictionCol The name of the column which contains the predicted cluster id
+ * for the point.
+ * @param featuresCol The name of the column which contains the feature vector of the point.
+ * @return The average of the Silhouette values of the clustered data.
+ */
+ def computeSilhouetteScore(
+ dataset: Dataset[_],
+ predictionCol: String,
+ featuresCol: String): Double = {
+ val normalizeFeatureUDF = udf {
+ features: Vector => {
+ val norm = Vectors.norm(features, 2.0)
+ features match {
+ case d: DenseVector => Vectors.dense(d.values.map(_ / norm))
+ case s: SparseVector => Vectors.sparse(s.size, s.indices, s.values.map(_ / norm))
+ }
+ }
+ }
+ val dfWithNormalizedFeatures = dataset.withColumn(normalizedFeaturesColName,
+ normalizeFeatureUDF(col(featuresCol)))
+
+ // compute aggregate values for clusters needed by the algorithm
+ val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol,
+ predictionCol)
+
+ // Silhouette is reasonable only when the number of clusters is greater than 1
+ assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.")
+
+ val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap)
+
+ val computeSilhouetteCoefficientUDF = udf {
+ computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double)
+ }
+
+ val silhouetteScore = overallScore(dfWithNormalizedFeatures,
+ computeSilhouetteCoefficientUDF(col(normalizedFeaturesColName),
+ col(predictionCol).cast(DoubleType)))
bClustersStatsMap.destroy()
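Putting the new distanceMeasure param together with the CosineSilhouette implementation above, evaluation could look roughly like this (a hedged sketch; `predictions` is assumed to be the output of a clustering model with "features" and "prediction" columns):

    import org.apache.spark.ml.evaluation.ClusteringEvaluator

    val evaluator = new ClusteringEvaluator()
      .setFeaturesCol("features")
      .setPredictionCol("prediction")

    // Unchanged default: silhouette with the squared Euclidean distance.
    val squaredEuclidean = evaluator.evaluate(predictions)

    // New in this patch: silhouette with the cosine distance measure.
    val cosine = evaluator.setDistanceMeasure("cosine").evaluate(predictions)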
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index 36a46ca6ff4b7..a906e954fecd5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -73,6 +73,14 @@ class BucketedRandomProjectionLSHModel private[ml](
private[ml] val randUnitVectors: Array[Vector])
extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {
+ /** @group setParam */
+ @Since("2.4.0")
+ override def setInputCol(value: String): this.type = super.set(inputCol, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ override def setOutputCol(value: String): this.type = super.set(outputCol, value)
+
@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = {
key: Vector => {
@@ -230,7 +238,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject
val model = new BucketedRandomProjectionLSHModel(metadata.uid,
randUnitVectors.rowIter.toArray)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index c13bf47eacb94..f99649f7fa164 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -19,6 +19,10 @@ package org.apache.spark.ml.feature
import java.{util => ju}
+import org.json4s.JsonDSL._
+import org.json4s.JValue
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 16abc4949dea3..dbfb199ccd58f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -334,7 +334,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
val selectedFeatures = data.getAs[Seq[Int]](0).toArray
val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
val model = new ChiSqSelectorModel(metadata.uid, oldModel)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 60a4f918790a3..10c48c3f52085 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -70,19 +70,21 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
def getMinDF: Double = $(minDF)
/**
- * Specifies the maximum number of different documents a term must appear in to be included
- * in the vocabulary.
- * If this is an integer greater than or equal to 1, this specifies the number of documents
- * the term must appear in; if this is a double in [0,1), then this specifies the fraction of
- * documents.
+ * Specifies the maximum number of different documents a term could appear in to be included
+ * in the vocabulary. A term that appears in more documents than the threshold will be ignored.
+ * If this is an integer greater than or equal to 1, this specifies the maximum number of
+ * documents the term could appear in; if this is a double in [0,1), then this specifies the
+ * maximum fraction of documents the term could appear in.
*
- * Default: (2^64^) - 1
+ * Default: (2^63^) - 1
* @group param
*/
val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" +
- " different documents a term must appear in to be included in the vocabulary." +
- " If this is an integer >= 1, this specifies the number of documents the term must" +
- " appear in; if this is a double in [0,1), then this specifies the fraction of documents.",
+ " different documents a term could appear in to be included in the vocabulary." +
+ " A term that appears more than the threshold will be ignored. If this is an integer >= 1," +
+ " this specifies the maximum number of documents the term could appear in;" +
+ " if this is a double in [0,1), then this specifies the maximum fraction of" +
+ " documents the term could appear in.",
ParamValidators.gtEq(0.0))
/** @group getParam */
@@ -361,7 +363,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
.head()
val vocabulary = data.getAs[Seq[String]](0).toArray
val model = new CountVectorizerModel(metadata.uid, vocabulary)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
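A hedged reading of the clarified maxDF semantics (assuming the usual setMaxDF setter on CountVectorizer; data and column names are illustrative): an integer value is a document-count ceiling, a value in [0, 1) a document-fraction ceiling.

    import spark.implicits._
    import org.apache.spark.ml.feature.CountVectorizer

    val docs = Seq(
      (0, Seq("a", "b", "c")),
      (1, Seq("a", "b")),
      (2, Seq("a"))
    ).toDF("id", "words")

    // "a" appears in 3 documents, which exceeds maxDF = 2, so it is dropped;
    // "b" and "c" stay in the vocabulary.
    val cvModel = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setMaxDF(2.0)
      .fit(docs)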
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
index a918dd4c075da..d67e4819b161a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
@@ -28,6 +29,8 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2Block}
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashMap
@@ -138,7 +141,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme
@Since("2.3.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- val hashFunc: Any => Int = OldHashingTF.murmur3Hash
+ val hashFunc: Any => Int = FeatureHasher.murmur3Hash
val n = $(numFeatures)
val localInputCols = $(inputCols)
val catCols = if (isSet(categoricalCols)) {
@@ -218,4 +221,31 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] {
@Since("2.3.0")
override def load(path: String): FeatureHasher = super.load(path)
+
+ private val seed = OldHashingTF.seed
+
+ /**
+ * Calculate a hash code value for the term object using
+ * Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32).
+ * This is the default hash algorithm used from Spark 2.0 onwards.
+ * Use hashUnsafeBytes2 to match the original algorithm with the value.
+ * See SPARK-23381.
+ */
+ @Since("2.3.0")
+ private[feature] def murmur3Hash(term: Any): Int = {
+ term match {
+ case null => seed
+ case b: Boolean => hashInt(if (b) 1 else 0, seed)
+ case b: Byte => hashInt(b, seed)
+ case s: Short => hashInt(s, seed)
+ case i: Int => hashInt(i, seed)
+ case l: Long => hashLong(l, seed)
+ case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
+ case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
+ case s: String =>
+ hashUnsafeBytes2Block(UTF8String.fromString(s).getMemoryBlock, seed)
+ case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " +
+ s"support type ${term.getClass.getCanonicalName} of input data.")
+ }
+ }
}
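The dedicated FeatureHasher.murmur3Hash above hashes booleans and numbers by value and strings via their UTF-8 bytes (per SPARK-23381), so heterogeneous columns can be hashed into one feature vector. A hedged usage sketch (assuming a SparkSession named spark; column names and the feature count are illustrative):

    import spark.implicits._
    import org.apache.spark.ml.feature.FeatureHasher

    val df = Seq(
      (2.2, true, "1", "foo"),
      (3.3, false, "2", "bar")
    ).toDF("real", "bool", "stringNum", "string")

    val hasher = new FeatureHasher()
      .setInputCols("real", "bool", "stringNum", "string")
      .setOutputCol("features")
      .setNumFeatures(1 << 10)

    hasher.transform(df).select("features").show(truncate = false)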
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 46a0730f5ddb8..58897cca4e5c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -182,7 +182,7 @@ object IDFModel extends MLReadable[IDFModel] {
.select("idf")
.head()
val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf)))
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 730ee9fc08db8..1c074e204ad99 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -262,7 +262,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
val dataPath = new Path(path, "data").toString
val surrogateDF = sqlContext.read.parquet(dataPath)
val model = new ImputerModel(metadata.uid, surrogateDF)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 1c9f47a0b201d..a70931f783f45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -65,6 +65,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
extends Model[T] with LSHParams with MLWritable {
self: T =>
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
/**
* The hash function of LSH, mapping an input feature vector to multiple hash vectors.
* @return The mapping of LSH function.
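The new setInputCol/setOutputCol on LSHModel (overridden with @Since("2.4.0") in the concrete models) lets a persisted model be re-pointed at differently named columns without refitting. A hedged sketch (path, column names and the dataset are illustrative assumptions):

    import org.apache.spark.ml.feature.BucketedRandomProjectionLSHModel

    // Re-point a saved model at differently named columns without refitting.
    val model = BucketedRandomProjectionLSHModel.load("/tmp/brp-lsh-model")
      .setInputCol("embedding")
      .setOutputCol("hashes")

    val hashed = model.transform(dataset)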
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 85f9732f79f67..90eceb0d61b40 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
.select("maxAbs")
.head()
val model = new MaxAbsScalerModel(metadata.uid, maxAbs)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index 145422a059196..a67a3b0abbc1f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -51,6 +51,14 @@ class MinHashLSHModel private[ml](
private[ml] val randCoefficients: Array[(Int, Int)])
extends LSHModel[MinHashLSHModel] {
+ /** @group setParam */
+ @Since("2.4.0")
+ override def setInputCol(value: String): this.type = super.set(inputCol, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ override def setOutputCol(value: String): this.type = super.set(outputCol, value)
+
@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = {
elems: Vector => {
@@ -197,7 +205,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
.map(tuple => (tuple(0), tuple(1))).toArray
val model = new MinHashLSHModel(metadata.uid, randCoefficients)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index f648deced54cd..2e0ae4af66f06 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -243,7 +243,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
.select("originalMin", "originalMax")
.head()
val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
index bd1e3426c8780..4a44f3186538d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
@@ -386,7 +386,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] {
.head()
val categorySizes = data.getAs[Seq[Int]](0).toArray
val model = new OneHotEncoderModel(metadata.uid, categorySizes)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 4143d864d7930..8172491a517d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -220,7 +220,7 @@ object PCAModel extends MLReadable[PCAModel] {
new PCAModel(metadata.uid, pc.asML,
Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
}
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 1ec5f8cb6139b..56e2c543d100a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -17,6 +17,10 @@
package org.apache.spark.ml.feature
+import org.json4s.JsonDSL._
+import org.json4s.JValue
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 22e7b8bbf1ff5..55e595eee6ffb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -278,6 +278,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
encoderStages += new VectorAssembler(uid)
.setInputCols(encodedTerms.toArray)
.setOutputCol($(featuresCol))
+ .setHandleInvalid($(handleInvalid))
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
encoderStages += new ColumnPruner(tempColumns.toSet)
@@ -445,7 +446,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
@@ -509,7 +510,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
val columnsToPrune = data.getAs[Seq[String]](0).toSet
val pruner = new ColumnPruner(metadata.uid, columnsToPrune)
- DefaultParamsReader.getAndSetParams(pruner, metadata)
+ metadata.getAndSetParams(pruner)
pruner
}
}
@@ -601,7 +602,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite
val prefixesToRewrite = data.getAs[Map[String, String]](1)
val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
- DefaultParamsReader.getAndSetParams(rewriter, metadata)
+ metadata.getAndSetParams(rewriter)
rewriter
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 8f125d8fd51d2..91b0707dec3f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -212,7 +212,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
.select("std", "mean")
.head()
val model = new StandardScalerModel(metadata.uid, std, mean)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 3fcd84c029e61..0f946dd2e015b 100755
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -17,9 +17,11 @@
package org.apache.spark.ml.feature
+import java.util.Locale
+
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
+import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -84,7 +86,27 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
@Since("1.5.0")
def getCaseSensitive: Boolean = $(caseSensitive)
- setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false)
+ /**
+ * Locale of the input for case insensitive matching. Ignored when [[caseSensitive]]
+ * is true.
+ * Default: Locale.getDefault.toString
+ * @group param
+ */
+ @Since("2.4.0")
+ val locale: Param[String] = new Param[String](this, "locale",
+ "Locale of the input for case insensitive matching. Ignored when caseSensitive is true.",
+ ParamValidators.inArray[String](Locale.getAvailableLocales.map(_.toString)))
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setLocale(value: String): this.type = set(locale, value)
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getLocale: String = $(locale)
+
+ setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
+ caseSensitive -> false, locale -> Locale.getDefault.toString)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
@@ -95,8 +117,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
terms.filter(s => !stopWordsSet.contains(s))
}
} else {
- // TODO: support user locale (SPARK-15064)
- val toLower = (s: String) => if (s != null) s.toLowerCase else s
+ val lc = new Locale($(locale))
+ val toLower = (s: String) => if (s != null) s.toLowerCase(lc) else s
val lowerStopWords = $(stopWords).map(toLower(_)).toSet
udf { terms: Seq[String] =>
terms.filter(s => !lowerStopWords.contains(toLower(s)))
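A hedged illustration of the new locale param (assuming the JVM exposes the "tr" locale; the stop-word list and data are illustrative): with a Turkish locale, case-insensitive matching folds the dotted capital İ to i, which the previous default-locale lower-casing could miss.

    import spark.implicits._
    import org.apache.spark.ml.feature.StopWordsRemover

    val remover = new StopWordsRemover()
      .setInputCol("raw")
      .setOutputCol("filtered")
      .setStopWords(Array("biri", "ama"))
      .setLocale("tr")

    // "BİRİ" lower-cases to "biri" under the Turkish locale and is removed.
    val df = Seq((0, Seq("BİRİ", "geldi"))).toDF("id", "raw")
    remover.transform(df).show(truncate = false)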
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 1cdcdfcaeab78..a833d8b270cf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -234,7 +234,7 @@ class StringIndexerModel (
val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
- val (filteredDataset, keepInvalid) = getHandleInvalid match {
+ val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case StringIndexer.SKIP_INVALID =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
@@ -315,7 +315,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
.head()
val labels = data.getAs[Seq[String]](0).toArray
val model = new StringIndexerModel(metadata.uid, labels)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index b373ae921ed38..4061154b39c14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -17,14 +17,17 @@
package org.apache.spark.ml.feature
-import scala.collection.mutable.ArrayBuilder
+import java.util.NoSuchElementException
+
+import scala.collection.mutable
+import scala.language.existentials
import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -33,10 +36,14 @@ import org.apache.spark.sql.types._
/**
* A feature transformer that merges multiple columns into a vector column.
+ *
+ * This requires one pass over the entire dataset. In case we need to infer column lengths from the
+ * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter.
*/
@Since("1.4.0")
class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
- extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
+ extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
+ with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+ /**
+ * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
+ * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
+ * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
+ * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
+ * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
+ * Default: "error"
+ * @group param
+ */
+ @Since("2.4.0")
+ override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+ """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out
+ |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN
+ |in the output). Column lengths are taken from the size of ML Attribute Group, which can be
+ |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also
+ |be inferred from first rows of the data since it is safe to do so but only in case of 'error'
+ |or 'skip'.""".stripMargin.replaceAll("\n", " "),
+ ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
+
+ setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Schema transformation.
val schema = dataset.schema
- lazy val first = dataset.toDF.first()
- val attrs = $(inputCols).flatMap { c =>
+
+ val vectorCols = $(inputCols).filter { c =>
+ schema(c).dataType match {
+ case _: VectorUDT => true
+ case _ => false
+ }
+ }
+ val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid))
+
+ val featureAttributesMap = $(inputCols).map { c =>
val field = schema(c)
- val index = schema.fieldIndex(c)
field.dataType match {
case DoubleType =>
- val attr = Attribute.fromStructField(field)
- // If the input column doesn't have ML attribute, assume numeric.
- if (attr == UnresolvedAttribute) {
- Some(NumericAttribute.defaultAttr.withName(c))
- } else {
- Some(attr.withName(c))
+ val attribute = Attribute.fromStructField(field)
+ attribute match {
+ case UnresolvedAttribute =>
+ Seq(NumericAttribute.defaultAttr.withName(c))
+ case _ =>
+ Seq(attribute.withName(c))
}
case _: NumericType | BooleanType =>
// If the input column type is a compatible scalar type, assume numeric.
- Some(NumericAttribute.defaultAttr.withName(c))
+ Seq(NumericAttribute.defaultAttr.withName(c))
case _: VectorUDT =>
- val group = AttributeGroup.fromStructField(field)
- if (group.attributes.isDefined) {
- // If attributes are defined, copy them with updated names.
- group.attributes.get.zipWithIndex.map { case (attr, i) =>
+ val attributeGroup = AttributeGroup.fromStructField(field)
+ if (attributeGroup.attributes.isDefined) {
+ attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) =>
if (attr.name.isDefined) {
// TODO: Define a rigorous naming scheme.
attr.withName(c + "_" + attr.name.get)
@@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
} else {
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
// from metadata, check the first row.
- val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
- Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
+ (0 until vectorColsLengths(c)).map { i =>
+ NumericAttribute.defaultAttr.withName(c + "_" + i)
+ }
}
case otherType =>
throw new SparkException(s"VectorAssembler does not support the $otherType type")
}
}
- val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
-
+ val featureAttributes = featureAttributesMap.flatten[Attribute].toArray
+ val lengths = featureAttributesMap.map(a => a.length).toArray
+ val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata()
+ val (filteredDataset, keepInvalid) = $(handleInvalid) match {
+ case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false)
+ case VectorAssembler.KEEP_INVALID => (dataset, true)
+ case VectorAssembler.ERROR_INVALID => (dataset, false)
+ }
// Data transformation.
val assembleFunc = udf { r: Row =>
- VectorAssembler.assemble(r.toSeq: _*)
+ VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*)
}.asNondeterministic()
val args = $(inputCols).map { c =>
schema(c).dataType match {
@@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}
}
- dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
+ filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
}
@Since("1.4.0")
@@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.6.0")
object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {
+ private[feature] val SKIP_INVALID: String = "skip"
+ private[feature] val ERROR_INVALID: String = "error"
+ private[feature] val KEEP_INVALID: String = "keep"
+ private[feature] val supportedHandleInvalids: Array[String] =
+ Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
+ /**
+ * Infers lengths of vector columns from the first row of the dataset.
+ * @param dataset the dataset
+ * @param columns names of the vector columns whose lengths need to be inferred
+ * @return map of column names to lengths
+ */
+ private[feature] def getVectorLengthsFromFirstRow(
+ dataset: Dataset[_],
+ columns: Seq[String]): Map[String, Int] = {
+ try {
+ val first_row = dataset.toDF().select(columns.map(col): _*).first()
+ columns.zip(first_row.toSeq).map {
+ case (c, x) => c -> x.asInstanceOf[Vector].size
+ }.toMap
+ } catch {
+ case e: NullPointerException => throw new NullPointerException(
+ s"""Encountered null value while inferring lengths from the first row. Consider using
+ |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+ .stripMargin.replaceAll("\n", " ") + e.toString)
+ case e: NoSuchElementException => throw new NoSuchElementException(
+ s"""Encountered empty dataframe while inferring lengths from the first row. Consider using
+ |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """
+ .stripMargin.replaceAll("\n", " ") + e.toString)
+ }
+ }
+
+ private[feature] def getLengths(
+ dataset: Dataset[_],
+ columns: Seq[String],
+ handleInvalid: String): Map[String, Int] = {
+ val groupSizes = columns.map { c =>
+ c -> AttributeGroup.fromStructField(dataset.schema(c)).size
+ }.toMap
+ val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq
+ val firstSizes = (missingColumns.nonEmpty, handleInvalid) match {
+ case (true, VectorAssembler.ERROR_INVALID) =>
+ getVectorLengthsFromFirstRow(dataset, missingColumns)
+ case (true, VectorAssembler.SKIP_INVALID) =>
+ getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns)
+ case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException(
+ s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint
+ |to add metadata for columns: ${columns.mkString("[", ", ", "]")}."""
+ .stripMargin.replaceAll("\n", " "))
+ case (_, _) => Map.empty
+ }
+ groupSizes ++ firstSizes
+ }
+
+
@Since("1.6.0")
override def load(path: String): VectorAssembler = super.load(path)
- private[feature] def assemble(vv: Any*): Vector = {
- val indices = ArrayBuilder.make[Int]
- val values = ArrayBuilder.make[Double]
- var cur = 0
+ /**
+ * Returns a function that has the required information to assemble each row.
+ * @param lengths an array of lengths of input columns, whose size should be equal to the number
+ * of cells in the row (vv)
+ * @param keepInvalid when true, nulls are kept and emitted as NaN instead of causing an error
+ * @return a function, suitable for wrapping in a udf, that assembles a single row
+ */
+ private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
+ val indices = mutable.ArrayBuilder.make[Int]
+ val values = mutable.ArrayBuilder.make[Double]
+ var featureIndex = 0
+
+ var inputColumnIndex = 0
vv.foreach {
case v: Double =>
- if (v != 0.0) {
- indices += cur
+ if (v.isNaN && !keepInvalid) {
+ throw new SparkException(
+ s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider
+ |removing NaNs from dataset or using handleInvalid = "keep" or "skip"."""
+ .stripMargin)
+ } else if (v != 0.0) {
+ indices += featureIndex
values += v
}
- cur += 1
+ inputColumnIndex += 1
+ featureIndex += 1
case vec: Vector =>
vec.foreachActive { case (i, v) =>
if (v != 0.0) {
- indices += cur + i
+ indices += featureIndex + i
values += v
}
}
- cur += vec.size
+ inputColumnIndex += 1
+ featureIndex += vec.size
case null =>
- // TODO: output Double.NaN?
- throw new SparkException("Values to assemble cannot be null.")
+ if (keepInvalid) {
+ val length: Int = lengths(inputColumnIndex)
+ Array.range(0, length).foreach { i =>
+ indices += featureIndex + i
+ values += Double.NaN
+ }
+ inputColumnIndex += 1
+ featureIndex += length
+ } else {
+ throw new SparkException(
+ s"""Encountered null while assembling a row with handleInvalid = "keep". Consider
+ |removing nulls from dataset or using handleInvalid = "keep" or "skip"."""
+ .stripMargin)
+ }
case o =>
throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
}
- Vectors.sparse(cur, indices.result(), values.result()).compressed
+ Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
}
}
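
The new handleInvalid param and the length-aware assemble above can be exercised end to end
through the public API. A minimal, hedged sketch (not part of this patch), assuming a
SparkSession named `spark`; the column names and values are illustrative only:

  import org.apache.spark.ml.feature.VectorAssembler

  val df = spark.createDataFrame(Seq(
    (0, 1.0, Some(2.0)),
    (1, 3.0, None)          // null in column "b"
  )).toDF("id", "a", "b")

  val assembler = new VectorAssembler()
    .setInputCols(Array("a", "b"))
    .setOutputCol("features")
    .setHandleInvalid("keep")   // "skip" drops the row, "error" (default) throws

  // With "keep", the null in "b" is emitted as NaN in the assembled vector.
  assembler.transform(df).show(false)
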
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index e6ec4e2e36ff0..0e7396a621dbd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -537,7 +537,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] {
val numFeatures = data.getAs[Int](0)
val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1)
val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
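
The one-line change above repeats across many of the readers in this patch: the static helper
DefaultParamsReader.getAndSetParams(model, metadata) is replaced by an instance method on the
loaded Metadata. A hedged sketch of the call-site pattern, where `model` is any freshly
constructed Params instance:

  // Before: DefaultParamsReader.getAndSetParams(model, metadata)
  // After:  the Metadata returned by loadMetadata applies its own stored params.
  val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
  metadata.getAndSetParams(model)
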
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index fe3306e1e50d6..fc9996d69ba72 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -410,7 +410,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
}
val model = new Word2VecModel(metadata.uid, oldModel)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index aa7871d6ff29d..d7fbe28ae7a64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
/**
* Common params for FPGrowth and FPGrowthModel
@@ -158,19 +159,35 @@ class FPGrowth @Since("2.2.0") (
}
private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+
+ val instr = Instrumentation.create(this, dataset)
+ instr.logParams(params: _*)
val data = dataset.select($(itemsCol))
- val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
+ val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray)
val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
if (isSet(numPartitions)) {
mllibFP.setNumPartitions($(numPartitions))
}
+
+ if (handlePersistence) {
+ items.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
val parentModel = mllibFP.run(items)
val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
val schema = StructType(Seq(
StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false),
StructField("freq", LongType, nullable = false)))
val frequentItems = dataset.sparkSession.createDataFrame(rows, schema)
- copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
+
+ if (handlePersistence) {
+ items.unpersist()
+ }
+
+ val model = copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
+ instr.logSuccess(model)
+ model
}
@Since("2.2.0")
@@ -322,7 +339,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
val dataPath = new Path(path, "data").toString
val frequentItems = sparkSession.read.parquet(dataPath)
val model = new FPGrowthModel(metadata.uid, frequentItems)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
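
The persistence handling added to genericFit only kicks in when the caller has not already cached
the input. A hedged usage sketch (not part of the patch), assuming a SparkSession named `spark`:

  import org.apache.spark.ml.fpm.FPGrowth
  import org.apache.spark.storage.StorageLevel

  val transactions = spark.createDataFrame(Seq(
    (0, Array("a", "b", "c")),
    (1, Array("a", "b")),
    (2, Array("a"))
  )).toDF("id", "items")

  // Because the input is already persisted, dataset.storageLevel != StorageLevel.NONE,
  // so fit() skips its internal MEMORY_AND_DISK persist of the items RDD.
  transactions.persist(StorageLevel.MEMORY_AND_DISK)

  val model = new FPGrowth()
    .setItemsCol("items")
    .setMinSupport(0.5)
    .fit(transactions)

  model.freqItemsets.show()
  transactions.unpersist()
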
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
new file mode 100644
index 0000000000000..bd1c1a8885201
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.fpm
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
+
+/**
+ * :: Experimental ::
+ * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
+ * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
+ * Efficiently by Prefix-Projected Pattern Growth
+ * (see here).
+ * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to
+ * run the PrefixSpan algorithm.
+ *
+ * @see Sequential Pattern Mining
+ * (Wikipedia)
+ */
+@Since("2.4.0")
+@Experimental
+final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {
+
+ @Since("2.4.0")
+ def this() = this(Identifiable.randomUID("prefixSpan"))
+
+ /**
+ * Param for the minimal support level (default: `0.1`).
+ * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
+ * identified as frequent sequential patterns.
+ * @group param
+ */
+ @Since("2.4.0")
+ val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " +
+ "sequential pattern. Sequential pattern that appears more than " +
+ "(minSupport * size-of-the-dataset) " +
+ "times will be output.", ParamValidators.gtEq(0.0))
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getMinSupport: Double = $(minSupport)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMinSupport(value: Double): this.type = set(minSupport, value)
+
+ /**
+ * Param for the maximal pattern length (default: `10`).
+ * @group param
+ */
+ @Since("2.4.0")
+ val maxPatternLength = new IntParam(this, "maxPatternLength",
+ "The maximal length of the sequential pattern.",
+ ParamValidators.gt(0))
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getMaxPatternLength: Int = $(maxPatternLength)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)
+
+ /**
+ * Param for the maximum number of items (including delimiters used in the internal storage
+ * format) allowed in a projected database before local processing (default: `32000000`).
+ * If a projected database exceeds this size, another iteration of distributed prefix growth
+ * is run.
+ * @group param
+ */
+ @Since("2.4.0")
+ val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
+ "The maximum number of items (including delimiters used in the internal storage format) " +
+ "allowed in a projected database before local processing. If a projected database exceeds " +
+ "this size, another iteration of distributed prefix growth is run.",
+ ParamValidators.gt(0))
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)
+
+ /**
+ * Param for the name of the sequence column in dataset (default "sequence"), rows with
+ * nulls in this column are ignored.
+ * @group param
+ */
+ @Since("2.4.0")
+ val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " +
+ "dataset, rows with nulls in this column are ignored.")
+
+ /** @group getParam */
+ @Since("2.4.0")
+ def getSequenceCol: String = $(sequenceCol)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setSequenceCol(value: String): this.type = set(sequenceCol, value)
+
+ setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
+ sequenceCol -> "sequence")
+
+ /**
+ * :: Experimental ::
+ * Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
+ *
+ * @param dataset A dataset or a dataframe containing a sequence column which is
+ * {{{ArrayType(ArrayType(T))}}} type, T is the item type for the input dataset.
+ * @return A `DataFrame` that contains columns of sequence and corresponding frequency.
+ * The schema of it will be:
+ * - `sequence: ArrayType(ArrayType(T))` (T is the item type)
+ * - `freq: Long`
+ */
+ @Since("2.4.0")
+ def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = {
+ val sequenceColParam = $(sequenceCol)
+ val inputType = dataset.schema(sequenceColParam).dataType
+ require(inputType.isInstanceOf[ArrayType] &&
+ inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType],
+ s"The input column must be ArrayType and the array element type must also be ArrayType, " +
+ s"but got $inputType.")
+
+ val data = dataset.select(sequenceColParam)
+ val sequences = data.where(col(sequenceColParam).isNotNull).rdd
+ .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)
+
+ val mllibPrefixSpan = new mllibPrefixSpan()
+ .setMinSupport($(minSupport))
+ .setMaxPatternLength($(maxPatternLength))
+ .setMaxLocalProjDBSize($(maxLocalProjDBSize))
+
+ val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
+ val schema = StructType(Seq(
+ StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false),
+ StructField("freq", LongType, nullable = false)))
+ val freqSequences = dataset.sparkSession.createDataFrame(rows, schema)
+
+ freqSequences
+ }
+
+ @Since("2.4.0")
+ override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)
+
+}
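
A hedged usage sketch for the new class (not part of the patch), assuming a SparkSession named
`spark`; the data is illustrative:

  import org.apache.spark.ml.fpm.PrefixSpan

  // Each row is a sequence of itemsets, i.e. ArrayType(ArrayType(StringType)) here.
  val sequences = spark.createDataFrame(Seq(
    (0, Seq(Seq("a"), Seq("a", "b", "c"), Seq("a", "c"))),
    (1, Seq(Seq("a"), Seq("c"), Seq("b", "c"))),
    (2, Seq(Seq("a", "b"), Seq("d")))
  )).toDF("id", "sequence")

  val patterns = new PrefixSpan()
    .setMinSupport(0.5)
    .setMaxPatternLength(5)
    .findFrequentSequentialPatterns(sequences)

  patterns.show(false)   // columns: sequence (array<array<string>>), freq (long)
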
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
index 6961b45f55e4d..572b8cf0051b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.optim
-import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.util.OptionalInstrumentation
import org.apache.spark.rdd.RDD
/**
@@ -61,9 +61,12 @@ private[ml] class IterativelyReweightedLeastSquares(
val fitIntercept: Boolean,
val regParam: Double,
val maxIter: Int,
- val tol: Double) extends Logging with Serializable {
+ val tol: Double) extends Serializable {
- def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = {
+ def fit(
+ instances: RDD[OffsetInstance],
+ instr: OptionalInstrumentation = OptionalInstrumentation.create(
+ classOf[IterativelyReweightedLeastSquares])): IterativelyReweightedLeastSquaresModel = {
var converged = false
var iter = 0
@@ -83,7 +86,8 @@ private[ml] class IterativelyReweightedLeastSquares(
// Estimate new model
model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0,
- standardizeFeatures = false, standardizeLabel = false).fit(newInstances)
+ standardizeFeatures = false, standardizeLabel = false)
+ .fit(newInstances, instr = instr)
// Check convergence
val oldCoefficients = oldModel.coefficients
@@ -96,14 +100,14 @@ private[ml] class IterativelyReweightedLeastSquares(
if (maxTol < tol) {
converged = true
- logInfo(s"IRLS converged in $iter iterations.")
+ instr.logInfo(s"IRLS converged in $iter iterations.")
}
- logInfo(s"Iteration $iter : relative tolerance = $maxTol")
+ instr.logInfo(s"Iteration $iter : relative tolerance = $maxTol")
iter = iter + 1
if (iter == maxIter) {
- logInfo(s"IRLS reached the max number of iterations: $maxIter.")
+ instr.logInfo(s"IRLS reached the max number of iterations: $maxIter.")
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index c5c9c8eb2bd29..1b7c15f1f0a8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.optim
-import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.util.OptionalInstrumentation
import org.apache.spark.rdd.RDD
/**
@@ -81,13 +81,11 @@ private[ml] class WeightedLeastSquares(
val standardizeLabel: Boolean,
val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto,
val maxIter: Int = 100,
- val tol: Double = 1e-6) extends Logging with Serializable {
+ val tol: Double = 1e-6
+ ) extends Serializable {
import WeightedLeastSquares._
require(regParam >= 0.0, s"regParam cannot be negative: $regParam")
- if (regParam == 0.0) {
- logWarning("regParam is zero, which might cause numerical instability and overfitting.")
- }
require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0,
s"elasticNetParam must be in [0, 1]: $elasticNetParam")
require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter")
@@ -96,10 +94,17 @@ private[ml] class WeightedLeastSquares(
/**
* Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.
*/
- def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = {
+ def fit(
+ instances: RDD[Instance],
+ instr: OptionalInstrumentation = OptionalInstrumentation.create(classOf[WeightedLeastSquares])
+ ): WeightedLeastSquaresModel = {
+ if (regParam == 0.0) {
+ instr.logWarning("regParam is zero, which might cause numerical instability and overfitting.")
+ }
+
val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))
summary.validate()
- logInfo(s"Number of instances: ${summary.count}.")
+ instr.logInfo(s"Number of instances: ${summary.count}.")
val k = if (fitIntercept) summary.k + 1 else summary.k
val numFeatures = summary.k
val triK = summary.triK
@@ -114,11 +119,12 @@ private[ml] class WeightedLeastSquares(
if (rawBStd == 0) {
if (fitIntercept || rawBBar == 0.0) {
if (rawBBar == 0.0) {
- logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " +
- s"and the intercept will all be zero; as a result, training is not needed.")
+ instr.logWarning(s"Mean and standard deviation of the label are zero, so the " +
+ s"coefficients and the intercept will all be zero; as a result, training is not " +
+ s"needed.")
} else {
- logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
- s"zeros and the intercept will be the mean of the label; as a result, " +
+ instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " +
+ s"will be zeros and the intercept will be the mean of the label; as a result, " +
s"training is not needed.")
}
val coefficients = new DenseVector(Array.ofDim(numFeatures))
@@ -128,7 +134,7 @@ private[ml] class WeightedLeastSquares(
} else {
require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " +
"zero. Model cannot be regularized with standardization=true")
- logWarning(s"The standard deviation of the label is zero. Consider setting " +
+ instr.logWarning(s"The standard deviation of the label is zero. Consider setting " +
s"fitIntercept=true.")
}
}
@@ -256,7 +262,7 @@ private[ml] class WeightedLeastSquares(
// if Auto solver is used and Cholesky fails due to singular AtA, then fall back to
// Quasi-Newton solver.
case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto =>
- logWarning("Cholesky solver failed due to singular covariance matrix. " +
+ instr.logWarning("Cholesky solver failed due to singular covariance matrix. " +
"Retrying with Quasi-Newton solver.")
// ab and aa were modified in place, so reconstruct them
val _aa = getAtA(aaBarValues, aBarValues)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 9a83a5882ce29..e6c347ed17c15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -865,10 +865,10 @@ trait Params extends Identifiable with Serializable {
}
/** Internal param map for user-supplied values. */
- private val paramMap: ParamMap = ParamMap.empty
+ private[ml] val paramMap: ParamMap = ParamMap.empty
/** Internal param map for default values. */
- private val defaultParamMap: ParamMap = ParamMap.empty
+ private[ml] val defaultParamMap: ParamMap = ParamMap.empty
/** Validates that the input param belongs to this instance. */
private def shouldOwn(param: Param[_]): Unit = {
@@ -905,6 +905,15 @@ trait Params extends Identifiable with Serializable {
}
}
+private[ml] object Params {
+ /**
+ * Sets a default param value for a `Params`.
+ */
+ private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = {
+ params.defaultParamMap.put(param -> value)
+ }
+}
+
/**
* :: DeveloperApi ::
* Java-friendly wrapper for [[Params]].
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 6ad44af9ef7eb..7e08675f834da 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -91,7 +91,14 @@ private[shared] object SharedParamsCodeGen {
"after fitting. If set to true, then all sub-models will be available. Warning: For " +
"large models, collecting all sub-models can cause OOMs on the Spark driver",
Some("false"), isExpertParam = true),
- ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false)
+ ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false),
+ ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" +
+ " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"),
+ isValid = "(value: String) => " +
+ "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)"),
+ ParamDesc[String]("validationIndicatorCol", "name of the column that indicates whether " +
+ "each row is for training or for validation. False indicates training; true indicates " +
+ "validation.")
)
val code = genSharedParams(params)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index be8b2f273164b..5928a0749f738 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -504,4 +504,40 @@ trait HasLoss extends Params {
/** @group getParam */
final def getLoss: String = $(loss)
}
+
+/**
+ * Trait for shared param distanceMeasure (default: org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN). This trait may be changed or
+ * removed between minor versions.
+ */
+@DeveloperApi
+trait HasDistanceMeasure extends Params {
+
+ /**
+ * Param for The distance measure. Supported options: 'euclidean' and 'cosine'.
+ * @group param
+ */
+ final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", (value: String) => org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value))
+
+ setDefault(distanceMeasure, org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN)
+
+ /** @group getParam */
+ final def getDistanceMeasure: String = $(distanceMeasure)
+}
+
+/**
+ * Trait for shared param validationIndicatorCol. This trait may be changed or
+ * removed between minor versions.
+ */
+@DeveloperApi
+trait HasValidationIndicatorCol extends Params {
+
+ /**
+ * Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation..
+ * @group param
+ */
+ final val validationIndicatorCol: Param[String] = new Param[String](this, "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.")
+
+ /** @group getParam */
+ final def getValidationIndicatorCol: String = $(validationIndicatorCol)
+}
// scalastyle:on
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 81a8f50761e0e..a23f9552b9e5f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -529,7 +529,7 @@ object ALSModel extends MLReadable[ALSModel] {
val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 4b46c3831d75f..e27a96e1f5dfc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -237,7 +237,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) {
- logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " +
+ instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " +
"constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " +
"columns. This behavior is different from R survival::survreg.")
}
@@ -423,7 +423,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
.head()
val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 0291a57487c47..8bcf0793a64c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
@Since("1.4.0")
class DecisionTreeRegressionModel private[ml] (
override val uid: String,
- override val rootNode: Node,
+ override val rootNode: RegressionNode,
override val numFeatures: Int)
extends PredictionModel[Vector, DecisionTreeRegressionModel]
with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {
@@ -175,10 +175,10 @@ class DecisionTreeRegressionModel private[ml] (
* Construct a decision tree regression model.
* @param rootNode Root node of tree, with other nodes attached.
*/
- private[ml] def this(rootNode: Node, numFeatures: Int) =
+ private[ml] def this(rootNode: RegressionNode, numFeatures: Int) =
this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
}
@@ -279,9 +279,10 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
- val root = loadTreeNodes(path, metadata, sparkSession)
- val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false)
+ val model = new DecisionTreeRegressionModel(metadata.uid,
+ root.asInstanceOf[RegressionNode], numFeatures)
+ metadata.getAndSetParams(model)
model
}
}
@@ -295,8 +296,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
require(oldModel.algo == OldAlgo.Regression,
s"Cannot convert non-regression DecisionTreeModel (old API) to" +
s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
- val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
- new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
+ new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index f41d15b62dddd..eb8b3c001436a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
/**
@@ -145,21 +145,42 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
override def setFeatureSubsetStrategy(value: String): this.type =
set(featureSubsetStrategy, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setValidationIndicatorCol(value: String): this.type = {
+ set(validationIndicatorCol, value)
+ }
+
override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
- val numFeatures = oldDataset.first().features.size
+
+ val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
+
+ val (trainDataset, validationDataset) = if (withValidation) {
+ (
+ extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))),
+ extractLabeledPoints(dataset.filter(col($(validationIndicatorCol))))
+ )
+ } else {
+ (extractLabeledPoints(dataset), null)
+ }
+ val numFeatures = trainDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
instr.logNumFeatures(numFeatures)
- val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
- $(seed), $(featureSubsetStrategy))
+ val (baseLearners, learnerWeights) = if (withValidation) {
+ GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy))
+ } else {
+ GradientBoostedTrees.run(trainDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy))
+ }
val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
instr.logSuccess(m)
m
@@ -230,7 +251,7 @@ class GBTRegressionModel private[ml](
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
@@ -269,6 +290,21 @@ class GBTRegressionModel private[ml](
new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
}
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param dataset Dataset for validation.
+ * @param loss The loss function used to compute error. Supported options: squared, absolute
+ */
+ @Since("2.4.0")
+ def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
+ val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+ }
+ GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
+ convertToOldLossType(loss), OldAlgo.Regression)
+ }
+
@Since("2.0.0")
override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
}
@@ -302,16 +338,16 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
override def load(path: String): GBTRegressionModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
- EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
val trees: Array[DecisionTreeRegressionModel] = treesData.map {
case (treeMetadata, root) =>
- val tree =
- new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
- DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
+ root.asInstanceOf[RegressionNode], numFeatures)
+ treeMetadata.getAndSetParams(tree)
tree
}
@@ -319,7 +355,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
s" trees based on metadata but found ${trees.length} trees.")
val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
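
Together, setValidationIndicatorCol and evaluateEachIteration give GBTRegressor a
validation-driven training path. A hedged sketch (not part of the patch); `trainingData` and
`testData` are assumed DataFrames with "label" and "features" columns:

  import org.apache.spark.ml.regression.GBTRegressor
  import org.apache.spark.sql.functions.rand

  // Mark roughly 20% of the rows as validation; training then uses the runWithValidation path
  // and may stop before maxIter based on the validation error.
  val flagged = trainingData.withColumn("isVal", rand(42) < 0.2)

  val gbt = new GBTRegressor()
    .setMaxIter(50)
    .setValidationIndicatorCol("isVal")

  val model = gbt.fit(flagged)

  // Per-iteration loss on held-out data via the new evaluateEachIteration helper.
  val errors: Array[Double] = model.evaluateEachIteration(testData, "squared")
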
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 917a4d238d467..143c8a3548b1f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -404,7 +404,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0,
standardizeFeatures = true, standardizeLabel = true)
- val wlsModel = optimizer.fit(instances)
+ val wlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr))
val model = copyValues(
new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept)
.setParent(this))
@@ -418,10 +418,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
OffsetInstance(label, weight, offset, features)
}
// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
- val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam))
+ val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam),
+ instr = OptionalInstrumentation.create(instr))
val optimizer = new IterativelyReweightedLeastSquares(initialModel,
familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol))
- val irlsModel = optimizer.fit(instances)
+ val irlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr))
val model = copyValues(
new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
.setParent(this))
@@ -471,6 +472,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
private[regression] val epsilon: Double = 1E-16
+ private[regression] def ylogy(y: Double, mu: Double): Double = {
+ if (y == 0) 0.0 else y * math.log(y / mu)
+ }
+
/**
* Wrapper of family and link combination used in the model.
*/
@@ -488,7 +493,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
def initialize(
instances: RDD[OffsetInstance],
fitIntercept: Boolean,
- regParam: Double): WeightedLeastSquaresModel = {
+ regParam: Double,
+ instr: OptionalInstrumentation = OptionalInstrumentation.create(
+ classOf[GeneralizedLinearRegression])
+ ): WeightedLeastSquaresModel = {
val newInstances = instances.map { instance =>
val mu = family.initialize(instance.label, instance.weight)
val eta = predict(mu) - instance.offset
@@ -497,7 +505,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
// TODO: Make standardizeFeatures and standardizeLabel configurable.
val initialModel = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0,
standardizeFeatures = true, standardizeLabel = true)
- .fit(newInstances)
+ .fit(newInstances, instr)
initialModel
}
@@ -725,10 +733,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def variance(mu: Double): Double = mu * (1.0 - mu)
- private def ylogy(y: Double, mu: Double): Double = {
- if (y == 0) 0.0 else y * math.log(y / mu)
- }
-
override def deviance(y: Double, mu: Double, weight: Double): Double = {
2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu))
}
@@ -783,7 +787,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def variance(mu: Double): Double = mu
override def deviance(y: Double, mu: Double, weight: Double): Double = {
- 2.0 * weight * (y * math.log(y / mu) - (y - mu))
+ 2.0 * weight * (ylogy(y, mu) - (y - mu))
}
override def aic(
@@ -1010,7 +1014,7 @@ class GeneralizedLinearRegressionModel private[ml] (
private lazy val familyAndLink = FamilyAndLink(this)
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
predict(features, 0.0)
}
@@ -1146,7 +1150,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
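
Hoisting ylogy to the companion object lets the Poisson deviance reuse the zero-safe form: the old
expression y * math.log(y / mu) evaluates to NaN when y == 0, while ylogy returns the limit value
0.0, so the deviance for a zero count reduces to 2 * weight * mu. A small worked sketch (plain
Scala, not part of the patch):

  def ylogy(y: Double, mu: Double): Double =
    if (y == 0) 0.0 else y * math.log(y / mu)

  // Poisson deviance contribution of one observation with weight w.
  def poissonDeviance(y: Double, mu: Double, w: Double): Double =
    2.0 * w * (ylogy(y, mu) - (y - mu))

  // Old form with y = 0, mu = 1.5: 2.0 * (0.0 * math.log(0.0 / 1.5) - (0.0 - 1.5)) == NaN
  // New form:                      2.0 * (0.0 - (0.0 - 1.5))                       == 3.0
  poissonDeviance(0.0, 1.5, 1.0)   // 3.0
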
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 8faab52ea474b..b046897ab2b7e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -308,7 +308,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] {
val model = new IsotonicRegressionModel(
metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic))
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 6d3fe7a6c748c..c45ade94a4e33 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.{PipelineStage, PredictorParams}
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.linalg.BLAS._
@@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -338,7 +339,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
elasticNetParam = $(elasticNetParam), $(standardization), true,
solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol))
- val model = optimizer.fit(instances)
+ val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr))
// When it is trained by WeightedLeastSquares, training summary does not
// attach returned model.
val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept))
@@ -377,6 +378,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val yMean = ySummarizer.mean(0)
val rawYStd = math.sqrt(ySummarizer.variance(0))
+
+ instr.logNumExamples(ySummarizer.count)
+ instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean)
+ instr.logNamedValue(Instrumentation.loggerTags.varianceOfLabels, rawYStd)
+
if (rawYStd == 0.0) {
if ($(fitIntercept) || yMean == 0.0) {
// If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with
@@ -384,11 +390,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
// Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of
// the fitIntercept.
if (yMean == 0.0) {
- logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " +
- s"and the intercept will all be zero; as a result, training is not needed.")
+ instr.logWarning(s"Mean and standard deviation of the label are zero, so the " +
+ s"coefficients and the intercept will all be zero; as a result, training is not " +
+ s"needed.")
} else {
- logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
- s"zeros and the intercept will be the mean of the label; as a result, " +
+ instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " +
+ s"will be zeros and the intercept will be the mean of the label; as a result, " +
s"training is not needed.")
}
if (handlePersistence) instances.unpersist()
@@ -414,7 +421,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
} else {
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
"Model cannot be regularized.")
- logWarning(s"The standard deviation of the label is zero. " +
+ instr.logWarning(s"The standard deviation of the label is zero. " +
"Consider setting fitIntercept=true.")
}
}
@@ -429,7 +436,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
- logWarning("Fitting LinearRegressionModel without intercept on dataset with " +
+ instr.logWarning("Fitting LinearRegressionModel without intercept on dataset with " +
"constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " +
"columns. This behavior is the same as R glmnet but different from LIBSVM.")
}
@@ -521,7 +528,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
}
if (state == null) {
val msg = s"${optimizer.getClass.getName} failed."
- logError(msg)
+ instr.logError(msg)
throw new SparkException(msg)
}
@@ -643,7 +650,7 @@ class LinearRegressionModel private[ml] (
@Since("1.3.0") val intercept: Double,
@Since("2.3.0") val scale: Double)
extends RegressionModel[Vector, LinearRegressionModel]
- with LinearRegressionParams with MLWritable {
+ with LinearRegressionParams with GeneralMLWritable {
private[ml] def this(uid: String, coefficients: Vector, intercept: Double) =
this(uid, coefficients, intercept, 1.0)
@@ -699,7 +706,7 @@ class LinearRegressionModel private[ml] (
}
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
dot(features, coefficients) + intercept
}
@@ -710,7 +717,7 @@ class LinearRegressionModel private[ml] (
}
/**
- * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
+ * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance.
*
* For [[LinearRegressionModel]], this does NOT currently save the training [[summary]].
* An option to save [[summary]] may be added in the future.
@@ -718,7 +725,50 @@ class LinearRegressionModel private[ml] (
* This also does not save the [[parent]] currently.
*/
@Since("1.6.0")
- override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this)
+ override def write: GeneralMLWriter = new GeneralMLWriter(this)
+}
+
+/** A writer for LinearRegression that handles the "internal" (or default) format */
+private class InternalLinearRegressionModelWriter
+ extends MLWriterFormat with MLFormatRegister {
+
+ override def format(): String = "internal"
+ override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+
+ private case class Data(intercept: Double, coefficients: Vector, scale: Double)
+
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ val instance = stage.asInstanceOf[LinearRegressionModel]
+ val sc = sparkSession.sparkContext
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: intercept, coefficients, scale
+ val data = Data(instance.intercept, instance.coefficients, instance.scale)
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+}
+
+/** A writer for LinearRegression that handles the "pmml" format */
+private class PMMLLinearRegressionModelWriter
+ extends MLWriterFormat with MLFormatRegister {
+
+ override def format(): String = "pmml"
+
+ override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+
+ private case class Data(intercept: Double, coefficients: Vector)
+
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ val sc = sparkSession.sparkContext
+ // Construct the MLLib model which knows how to write to PMML.
+ val instance = stage.asInstanceOf[LinearRegressionModel]
+ val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept)
+ // Save PMML
+ oldModel.toPMML(sc, path)
+ }
}
@Since("1.6.0")
@@ -730,22 +780,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
@Since("1.6.0")
override def load(path: String): LinearRegressionModel = super.load(path)
- /** [[MLWriter]] instance for [[LinearRegressionModel]] */
- private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel)
- extends MLWriter with Logging {
-
- private case class Data(intercept: Double, coefficients: Vector, scale: Double)
-
- override protected def saveImpl(path: String): Unit = {
- // Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
- // Save model data: intercept, coefficients, scale
- val data = Data(instance.intercept, instance.coefficients, instance.scale)
- val dataPath = new Path(path, "data").toString
- sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
- }
- }
-
private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] {
/** Checked against metadata when loading model */
@@ -771,7 +805,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
new LinearRegressionModel(metadata.uid, coefficients, intercept, scale)
}
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
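
With write now returning a GeneralMLWriter, the output format is chosen by name and dispatched to
the registered MLWriterFormat ("internal" remains the default layout). A hedged sketch (not part
of the patch); `trainingData` is an assumed DataFrame with "label" and "features" columns:

  import org.apache.spark.ml.regression.LinearRegression

  val model = new LinearRegression().setMaxIter(10).fit(trainingData)

  // Default "internal" format: same Parquet layout the old private writer produced.
  model.write.overwrite().save("/tmp/lr-internal")

  // PMML export, routed through PMMLLinearRegressionModelWriter via the format registry.
  model.write.format("pmml").save("/tmp/lr-pmml")
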
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 200b234b79978..4509f85aafd12 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -199,7 +199,7 @@ class RandomForestRegressionModel private[ml] (
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
- override protected def predict(features: Vector): Double = {
+ override def predict(features: Vector): Double = {
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
@@ -269,21 +269,21 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
override def load(path: String): RandomForestRegressionModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
- EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
+ EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]
val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
- val tree =
- new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
- DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
+ root.asInstanceOf[RegressionNode], numFeatures)
+ treeMetadata.getAndSetParams(tree)
tree
}
require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" +
s" trees based on metadata but found ${trees.length} trees.")
val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures)
- DefaultParamsReader.getAndSetParams(model, metadata)
+ metadata.getAndSetParams(model)
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
new file mode 100644
index 0000000000000..adf8145726711
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat
+
+import scala.annotation.varargs
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.api.java.function.Function
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions.col
+
+/**
+ * :: Experimental ::
+ *
+ * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
+ * continuous distribution. By comparing the largest difference between the empirical cumulative
+ * distribution of the sample data and the theoretical distribution we can provide a test for the
+ * the null hypothesis that the sample data comes from that theoretical distribution.
+ * For more information on KS Test:
+ * @see
+ * Kolmogorov-Smirnov test (Wikipedia)
+ */
+@Experimental
+@Since("2.4.0")
+object KolmogorovSmirnovTest {
+
+ /** Used to construct output schema of test */
+ private case class KolmogorovSmirnovTestResult(
+ pValue: Double,
+ statistic: Double)
+
+ private def getSampleRDD(dataset: DataFrame, sampleCol: String): RDD[Double] = {
+ SchemaUtils.checkNumericType(dataset.schema, sampleCol)
+ import dataset.sparkSession.implicits._
+ dataset.select(col(sampleCol).cast("double")).as[Double].rdd
+ }
+
+ /**
+ * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
+ * continuous distribution. By comparing the largest difference between the empirical cumulative
+ * distribution of the sample data and the theoretical distribution we can provide a test for the
+ * null hypothesis that the sample data comes from that theoretical distribution.
+ *
+ * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test
+ * @param sampleCol Name of sample column in dataset, of any numerical type
+ * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value
+ * @return DataFrame containing the test result for the input sampled data.
+ * This DataFrame will contain a single Row with the following fields:
+ * - `pValue: Double`
+ * - `statistic: Double`
+ */
+ @Since("2.4.0")
+ def test(dataset: Dataset[_], sampleCol: String, cdf: Double => Double): DataFrame = {
+ val spark = dataset.sparkSession
+
+ val rdd = getSampleRDD(dataset.toDF(), sampleCol)
+ val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, cdf)
+ spark.createDataFrame(Seq(KolmogorovSmirnovTestResult(
+ testResult.pValue, testResult.statistic)))
+ }
+
+ /**
+ * Java-friendly version of `test(dataset: DataFrame, sampleCol: String, cdf: Double => Double)`
+ */
+ @Since("2.4.0")
+ def test(
+ dataset: Dataset[_],
+ sampleCol: String,
+ cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = {
+ test(dataset, sampleCol, (x: Double) => cdf.call(x).toDouble)
+ }
+
+ /**
+ * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability
+ * distribution equality. Currently supports the normal distribution, taking as parameters
+ * the mean and standard deviation.
+ *
+ * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test
+ * @param sampleCol Name of sample column in dataset, of any numerical type
+ * @param distName a `String` name for a theoretical distribution; currently only "norm" is supported.
+ * @param params `Double*` specifying the parameters to be used for the theoretical distribution.
+ * For "norm" distribution, the parameters includes mean and variance.
+ * @return DataFrame containing the test result for the input sampled data.
+ * This DataFrame will contain a single Row with the following fields:
+ * - `pValue: Double`
+ * - `statistic: Double`
+ */
+ @Since("2.4.0")
+ @varargs
+ def test(
+ dataset: Dataset[_],
+ sampleCol: String, distName: String,
+ params: Double*): DataFrame = {
+ val spark = dataset.sparkSession
+
+ val rdd = getSampleRDD(dataset.toDF(), sampleCol)
+ val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, distName, params: _*)
+ spark.createDataFrame(Seq(KolmogorovSmirnovTestResult(
+ testResult.pValue, testResult.statistic)))
+ }
+}
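For illustration, a minimal usage sketch of the KolmogorovSmirnovTest API added above. The SparkSession setup, the column name "sample", and the sample values are assumptions, not part of the patch:

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.ml.stat.KolmogorovSmirnovTest

val spark = SparkSession.builder().appName("KSTestSketch").getOrCreate()
import spark.implicits._

// Illustrative sample; the column name is arbitrary.
val data = Seq(0.1, 0.15, 0.2, 0.3, 0.25, -0.4, 0.05).toDF("sample")

// Test against the standard normal distribution N(mean = 0, sd = 1);
// the result is a single-row DataFrame with (pValue, statistic).
val Row(pValue: Double, statistic: Double) =
  KolmogorovSmirnovTest.test(data, "sample", "norm", 0.0, 1.0).head()

// Test against a user-supplied CDF, here the CDF of Uniform(-1, 1).
val uniformCdf = (x: Double) => math.min(math.max((x + 1.0) / 2.0, 0.0), 1.0)
val result = KolmogorovSmirnovTest.test(data, "sample", uniformCdf).head()
```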
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index 07e98a142b10e..0242bc76698d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -17,15 +17,16 @@
package org.apache.spark.ml.tree
+import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{ImpurityStats,
- InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict}
+import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats,
+ Node => OldNode, Predict => OldPredict}
/**
* Decision tree node interface.
*/
-sealed abstract class Node extends Serializable {
+sealed trait Node extends Serializable {
// TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
// code into the new API and deprecate the old API. SPARK-3727
@@ -85,35 +86,86 @@ private[ml] object Node {
/**
* Create a new Node from the old Node format, recursively creating child nodes as needed.
*/
- def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+ def fromOld(
+ oldNode: OldNode,
+ categoricalFeatures: Map[Int, Int],
+ isClassification: Boolean): Node = {
if (oldNode.isLeaf) {
// TODO: Once the implementation has been moved to this API, then include sufficient
// statistics here.
- new LeafNode(prediction = oldNode.predict.predict,
- impurity = oldNode.impurity, impurityStats = null)
+ if (isClassification) {
+ new ClassificationLeafNode(prediction = oldNode.predict.predict,
+ impurity = oldNode.impurity, impurityStats = null)
+ } else {
+ new RegressionLeafNode(prediction = oldNode.predict.predict,
+ impurity = oldNode.impurity, impurityStats = null)
+ }
} else {
val gain = if (oldNode.stats.nonEmpty) {
oldNode.stats.get.gain
} else {
0.0
}
- new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
- gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
- rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
- split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
+ if (isClassification) {
+ new ClassificationInternalNode(prediction = oldNode.predict.predict,
+ impurity = oldNode.impurity, gain = gain,
+ leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true)
+ .asInstanceOf[ClassificationNode],
+ rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true)
+ .asInstanceOf[ClassificationNode],
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
+ } else {
+ new RegressionInternalNode(prediction = oldNode.predict.predict,
+ impurity = oldNode.impurity, gain = gain,
+ leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false)
+ .asInstanceOf[RegressionNode],
+ rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false)
+ .asInstanceOf[RegressionNode],
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
+ }
}
}
}
-/**
- * Decision tree leaf node.
- * @param prediction Prediction this node makes
- * @param impurity Impurity measure at this node (for training data)
- */
-class LeafNode private[ml] (
- override val prediction: Double,
- override val impurity: Double,
- override private[ml] val impurityStats: ImpurityCalculator) extends Node {
+@Since("2.4.0")
+sealed trait ClassificationNode extends Node {
+
+ /**
+ * Get count of training examples for specified label in this node
+ * @param label label number in the range [0, numClasses)
+ */
+ @Since("2.4.0")
+ def getLabelCount(label: Int): Double = {
+ require(label >= 0 && label < impurityStats.stats.length,
+ "label should be in the range between 0 (inclusive) " +
+ s"and ${impurityStats.stats.length} (exclusive).")
+ impurityStats.stats(label)
+ }
+}
+
+@Since("2.4.0")
+sealed trait RegressionNode extends Node {
+
+ /** Number of training data points in this node */
+ @Since("2.4.0")
+ def getCount: Double = impurityStats.stats(0)
+
+ /** Sum over training data points of the labels in this node */
+ @Since("2.4.0")
+ def getSum: Double = impurityStats.stats(1)
+
+ /** Sum over training data points of the square of the labels in this node */
+ @Since("2.4.0")
+ def getSumOfSquares: Double = impurityStats.stats(2)
+}
+
+@Since("2.4.0")
+sealed trait LeafNode extends Node {
+
+ /** Prediction this node makes. */
+ def prediction: Double
+
+ /** Impurity measure at this node (for training data). */
+ def impurity: Double
override def toString: String =
s"LeafNode(prediction = $prediction, impurity = $impurity)"
@@ -136,32 +188,58 @@ class LeafNode private[ml] (
override private[ml] def maxSplitFeatureIndex(): Int = -1
+}
+
+/**
+ * Decision tree leaf node for classification.
+ */
+@Since("2.4.0")
+class ClassificationLeafNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double,
+ override private[ml] val impurityStats: ImpurityCalculator)
+ extends ClassificationNode with LeafNode {
+
override private[tree] def deepCopy(): Node = {
- new LeafNode(prediction, impurity, impurityStats)
+ new ClassificationLeafNode(prediction, impurity, impurityStats)
}
}
/**
- * Internal Decision Tree node.
- * @param prediction Prediction this node would make if it were a leaf node
- * @param impurity Impurity measure at this node (for training data)
- * @param gain Information gain value. Values less than 0 indicate missing values;
- * this quirk will be removed with future updates.
- * @param leftChild Left-hand child node
- * @param rightChild Right-hand child node
- * @param split Information about the test used to split to the left or right child.
+ * Decision tree leaf node for regression.
*/
-class InternalNode private[ml] (
+@Since("2.4.0")
+class RegressionLeafNode private[ml] (
override val prediction: Double,
override val impurity: Double,
- val gain: Double,
- val leftChild: Node,
- val rightChild: Node,
- val split: Split,
- override private[ml] val impurityStats: ImpurityCalculator) extends Node {
+ override private[ml] val impurityStats: ImpurityCalculator)
+ extends RegressionNode with LeafNode {
- // Note to developers: The constructor argument impurityStats should be reconsidered before we
- // make the constructor public. We may be able to improve the representation.
+ override private[tree] def deepCopy(): Node = {
+ new RegressionLeafNode(prediction, impurity, impurityStats)
+ }
+}
+
+/**
+ * Internal Decision Tree node.
+ */
+@Since("2.4.0")
+sealed trait InternalNode extends Node {
+
+ /**
+ * Information gain value. Values less than 0 indicate missing values;
+ * this quirk will be removed with future updates.
+ */
+ def gain: Double
+
+ /** Left-hand child node */
+ def leftChild: Node
+
+ /** Right-hand child node */
+ def rightChild: Node
+
+ /** Information about the test used to split to the left or right child. */
+ def split: Split
override def toString: String = {
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
@@ -206,11 +284,6 @@ class InternalNode private[ml] (
math.max(split.featureIndex,
math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
}
-
- override private[tree] def deepCopy(): Node = {
- new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(),
- split, impurityStats)
- }
}
private object InternalNode {
@@ -241,6 +314,57 @@ private object InternalNode {
}
}
+/**
+ * Internal Decision Tree node for classification.
+ */
+@Since("2.4.0")
+class ClassificationInternalNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double,
+ override val gain: Double,
+ override val leftChild: ClassificationNode,
+ override val rightChild: ClassificationNode,
+ override val split: Split,
+ override private[ml] val impurityStats: ImpurityCalculator)
+ extends ClassificationNode with InternalNode {
+
+ // Note to developers: The constructor argument impurityStats should be reconsidered before we
+ // make the constructor public. We may be able to improve the representation.
+
+ override private[tree] def deepCopy(): Node = {
+ new ClassificationInternalNode(prediction, impurity, gain,
+ leftChild.deepCopy().asInstanceOf[ClassificationNode],
+ rightChild.deepCopy().asInstanceOf[ClassificationNode],
+ split, impurityStats)
+ }
+}
+
+/**
+ * Internal Decision Tree node for regression.
+ */
+@Since("2.4.0")
+class RegressionInternalNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double,
+ override val gain: Double,
+ override val leftChild: RegressionNode,
+ override val rightChild: RegressionNode,
+ override val split: Split,
+ override private[ml] val impurityStats: ImpurityCalculator)
+ extends RegressionNode with InternalNode {
+
+ // Note to developers: The constructor argument impurityStats should be reconsidered before we
+ // make the constructor public. We may be able to improve the representation.
+
+ override private[tree] def deepCopy(): Node = {
+ new RegressionInternalNode(prediction, impurity, gain,
+ leftChild.deepCopy().asInstanceOf[RegressionNode],
+ rightChild.deepCopy().asInstanceOf[RegressionNode],
+ split, impurityStats)
+ }
+}
+
+
/**
* Version of a node used in learning. This uses vars so that we can modify nodes as we split the
* tree by adding children, etc.
@@ -266,24 +390,53 @@ private[tree] class LearningNode(
var isLeaf: Boolean,
var stats: ImpurityStats) extends Serializable {
+ def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true)
+
+ def toClassificationNode(prune: Boolean = true): ClassificationNode = {
+ toNode(true, prune).asInstanceOf[ClassificationNode]
+ }
+
+ def toRegressionNode(prune: Boolean = true): RegressionNode = {
+ toNode(false, prune).asInstanceOf[RegressionNode]
+ }
+
/**
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
*/
- def toNode: Node = {
- if (leftChild.nonEmpty) {
- assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
+ def toNode(isClassification: Boolean, prune: Boolean): Node = {
+
+ if (!leftChild.isEmpty || !rightChild.isEmpty) {
+ assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
- new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
- leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
+ (leftChild.get.toNode(isClassification, prune),
+ rightChild.get.toNode(isClassification, prune)) match {
+ case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction =>
+ if (isClassification) {
+ new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+ } else {
+ new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
+ }
+ case (l, r) =>
+ if (isClassification) {
+ new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity,
+ stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode],
+ split.get, stats.impurityCalculator)
+ } else {
+ new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+ l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode],
+ split.get, stats.impurityCalculator)
+ }
+ }
} else {
- if (stats.valid) {
- new LeafNode(stats.impurityCalculator.predict, stats.impurity,
+ // Here we want to keep the same behavior as the old mllib.DecisionTreeModel
+ val impurity = if (stats.valid) stats.impurity else -1.0
+ if (isClassification) {
+ new ClassificationLeafNode(stats.impurityCalculator.predict, impurity,
stats.impurityCalculator)
} else {
- // Here we want to keep same behavior with the old mllib.DecisionTreeModel
- new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
+ new RegressionLeafNode(stats.impurityCalculator.predict, impurity,
+ stats.impurityCalculator)
}
-
}
}
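A sketch of how the new typed node hierarchy might be traversed. It assumes the classification tree models elsewhere in this change expose a `ClassificationNode` root; the `describe` helper and the commented usage are illustrative:

```scala
import org.apache.spark.ml.tree.{ClassificationInternalNode, ClassificationLeafNode, ClassificationNode}

// Recursively print a classification tree, using the new per-class count accessor on leaves.
def describe(node: ClassificationNode, depth: Int = 0): Unit = node match {
  case leaf: ClassificationLeafNode =>
    // getLabelCount reads the per-label training counts from the impurity statistics.
    println(" " * depth + s"leaf: prediction=${leaf.prediction}, count(label 0)=${leaf.getLabelCount(0)}")
  case internal: ClassificationInternalNode =>
    println(" " * depth + s"split on feature ${internal.split.featureIndex}, gain=${internal.gain}")
    describe(internal.leftChild, depth + 2)
    describe(internal.rightChild, depth + 2)
}

// describe(model.rootNode)  // assuming a trained DecisionTreeClassificationModel `model`
```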
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
index a7c5f489dea86..5b14a63ada4ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -95,7 +95,7 @@ private[spark] class NodeIdCache(
splits: Array[Array[Split]]): Unit = {
if (prevNodeIdsForInstances != null) {
// Unpersist the previous one if one exists.
- prevNodeIdsForInstances.unpersist()
+ prevNodeIdsForInstances.unpersist(false)
}
prevNodeIdsForInstances = nodeIdsForInstances
@@ -166,9 +166,13 @@ private[spark] class NodeIdCache(
}
}
}
+ if (nodeIdsForInstances != null) {
+ // Unpersist current one if one exists.
+ nodeIdsForInstances.unpersist(false)
+ }
if (prevNodeIdsForInstances != null) {
// Unpersist the previous one if one exists.
- prevNodeIdsForInstances.unpersist()
+ prevNodeIdsForInstances.unpersist(false)
}
}
}
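The NodeIdCache change switches both unpersist calls to the non-blocking form so that tree training does not wait for cached blocks to be dropped. The same pattern on a plain RDD, assuming an existing SparkContext `sc`:

```scala
// Illustrative only: non-blocking unpersist returns immediately and removes
// cached blocks asynchronously, which is what the change above relies on.
val cached = sc.parallelize(1 to 100).cache()
cached.count()                      // materialize the cache
cached.unpersist(blocking = false)  // fire-and-forget; does not wait for block removal
```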
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index acfc6399c553b..905870178e549 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging {
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation[_]],
+ prune: Boolean = true, // exposed for testing only, real trees are always pruned
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
val timer = new TimeTracker()
@@ -107,9 +108,11 @@ private[spark] object RandomForest extends Logging {
case Some(instrumentation) =>
instrumentation.logNumFeatures(metadata.numFeatures)
instrumentation.logNumClasses(metadata.numClasses)
+ instrumentation.logNumExamples(metadata.numExamples)
case None =>
logInfo("numFeatures: " + metadata.numFeatures)
logInfo("numClasses: " + metadata.numClasses)
+ logInfo("numExamples: " + metadata.numExamples)
}
// Find the splits and the corresponding bins (interval between the splits) using a sample
@@ -223,22 +226,23 @@ private[spark] object RandomForest extends Logging {
case Some(uid) =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map { rootNode =>
- new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
- strategy.getNumClasses)
+ new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune),
+ numFeatures, strategy.getNumClasses)
}
} else {
topNodes.map { rootNode =>
- new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
+ new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures)
}
}
case None =>
if (strategy.algo == OldAlgo.Classification) {
topNodes.map { rootNode =>
- new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
+ new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures,
strategy.getNumClasses)
}
} else {
- topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
+ topNodes.map(rootNode =>
+ new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures))
}
}
}
@@ -890,13 +894,7 @@ private[spark] object RandomForest extends Logging {
// Sample the input only if there are continuous features.
val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
val sampledInput = if (continuousFeatures.nonEmpty) {
- // Calculate the number of samples for approximate quantile calculation.
- val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
- val fraction = if (requiredSamples < metadata.numExamples) {
- requiredSamples.toDouble / metadata.numExamples
- } else {
- 1.0
- }
+ val fraction = samplesFractionForFindSplits(metadata)
logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
} else {
@@ -918,8 +916,9 @@ private[spark] object RandomForest extends Logging {
val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
input
- .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
- .groupByKey(numPartitions)
+ .flatMap { point =>
+ continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0)
+ }.groupByKey(numPartitions)
.map { case (idx, samples) =>
val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
@@ -931,7 +930,8 @@ private[spark] object RandomForest extends Logging {
val numFeatures = metadata.numFeatures
val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
case i if metadata.isContinuous(i) =>
- val split = continuousSplits(i)
+ // Some features may contain only zeros, so continuousSplits will not have a record for them
+ val split = continuousSplits.getOrElse(i, Array.empty[Split])
metadata.setNumSplits(i, split.length)
split
@@ -1001,11 +1001,22 @@ private[spark] object RandomForest extends Logging {
} else {
val numSplits = metadata.numSplits(featureIndex)
- // get count for each distinct value
- val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
- case ((m, cnt), x) =>
- (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
+ // get count for each distinct value except zero value
+ val partNumSamples = featureSamples.size
+ val partValueCountMap = scala.collection.mutable.Map[Double, Int]()
+ featureSamples.foreach { x =>
+ partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1
+ }
+
+ // Calculate the expected number of samples for finding splits
+ val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt
+ // add expected zero value count and get complete statistics
+ val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) {
+ partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples))
+ } else {
+ partValueCountMap.toMap
}
+
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
@@ -1147,4 +1158,21 @@ private[spark] object RandomForest extends Logging {
3 * totalBins
}
}
+
+ /**
+ * Calculate the subsample fraction for finding splits
+ *
+ * @param metadata decision tree metadata
+ * @return subsample fraction
+ */
+ private def samplesFractionForFindSplits(
+ metadata: DecisionTreeMetadata): Double = {
+ // Calculate the number of samples for approximate quantile calculation.
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
+ if (requiredSamples < metadata.numExamples) {
+ requiredSamples.toDouble / metadata.numExamples
+ } else {
+ 1.0
+ }
+ }
}
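The extracted `samplesFractionForFindSplits` helper keeps the original sampling logic in one place. A standalone restatement with illustrative numbers:

```scala
// Standalone sketch of the sampling-fraction logic; maxBins and numExamples are illustrative.
def samplesFraction(maxBins: Int, numExamples: Long): Double = {
  // At most max(maxBins^2, 10000) samples are needed for approximate quantile calculation.
  val requiredSamples = math.max(maxBins * maxBins, 10000)
  if (requiredSamples < numExamples) requiredSamples.toDouble / numExamples else 1.0
}

samplesFraction(maxBins = 32, numExamples = 1000000L)  // 0.01: sample 1% of the rows
samplesFraction(maxBins = 32, numExamples = 5000L)     // 1.0: use every row
```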
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 4aa4c3617e7fd..f027b14f1d476 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -219,8 +219,10 @@ private[ml] object TreeEnsembleModel {
importances.changeValue(feature, scaledGain, _ + scaledGain)
computeFeatureImportance(n.leftChild, importances)
computeFeatureImportance(n.rightChild, importances)
- case n: LeafNode =>
+ case _: LeafNode =>
// do nothing
+ case _ =>
+ throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}")
}
}
@@ -317,6 +319,8 @@ private[ml] object DecisionTreeModelReadWrite {
(Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
-1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
id)
+ case _ =>
+ throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}")
}
}
@@ -327,7 +331,7 @@ private[ml] object DecisionTreeModelReadWrite {
def loadTreeNodes(
path: String,
metadata: DefaultParamsReader.Metadata,
- sparkSession: SparkSession): Node = {
+ sparkSession: SparkSession, isClassification: Boolean): Node = {
import sparkSession.implicits._
implicit val format = DefaultFormats
@@ -339,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite {
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).as[NodeData]
- buildTreeFromNodes(data.collect(), impurityType)
+ buildTreeFromNodes(data.collect(), impurityType, isClassification)
}
/**
@@ -348,7 +352,8 @@ private[ml] object DecisionTreeModelReadWrite {
* @param impurityType Impurity type for this tree
* @return Root node of reconstructed tree
*/
- def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
+ def buildTreeFromNodes(data: Array[NodeData], impurityType: String,
+ isClassification: Boolean): Node = {
// Load all nodes, sorted by ID.
val nodes = data.sortBy(_.id)
// Sanity checks; could remove
@@ -364,10 +369,21 @@ private[ml] object DecisionTreeModelReadWrite {
val node = if (n.leftChild != -1) {
val leftChild = finalNodes(n.leftChild)
val rightChild = finalNodes(n.rightChild)
- new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild,
- n.split.getSplit, impurityStats)
+ if (isClassification) {
+ new ClassificationInternalNode(n.prediction, n.impurity, n.gain,
+ leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode],
+ n.split.getSplit, impurityStats)
+ } else {
+ new RegressionInternalNode(n.prediction, n.impurity, n.gain,
+ leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode],
+ n.split.getSplit, impurityStats)
+ }
} else {
- new LeafNode(n.prediction, n.impurity, impurityStats)
+ if (isClassification) {
+ new ClassificationLeafNode(n.prediction, n.impurity, impurityStats)
+ } else {
+ new RegressionLeafNode(n.prediction, n.impurity, impurityStats)
+ }
}
finalNodes(n.id) = node
}
@@ -421,7 +437,8 @@ private[ml] object EnsembleModelReadWrite {
path: String,
sql: SparkSession,
className: String,
- treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
+ treeClassName: String,
+ isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
import sql.implicits._
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
@@ -449,7 +466,8 @@ private[ml] object EnsembleModelReadWrite {
val rootNodesRDD: RDD[(Int, Node)] =
nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
case (treeID: Int, nodeData: Iterable[NodeData]) =>
- treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+ treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(
+ nodeData.toArray, impurityType, isClassification)
}
val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
(metadata, treesMetadata.zip(rootNodes), treesWeights)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 81b6222acc7ce..00157fe63af41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -21,6 +21,7 @@ import java.util.Locale
import scala.util.Try
+import org.apache.spark.annotation.Since
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -460,18 +461,34 @@ private[ml] trait RandomForestRegressorParams
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize {
-
- /* TODO: Add this doc when we add this param. SPARK-7132
- * Threshold for stopping early when runWithValidation is used.
- * If the error rate on the validation input changes by less than the validationTol,
- * then learning will stop early (before [[numIterations]]).
- * This parameter is ignored when run is used.
- * (default = 1e-5)
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize
+ with HasValidationIndicatorCol {
+
+ /**
+ * Threshold for stopping early when fit with validation is used.
+ * (This parameter is ignored when fit without validation is used.)
+ * The decision to stop early is based on the following logic:
+ * if the current loss on the validation set is greater than 0.01, the difference
+ * in validation error is compared to a relative tolerance of
+ * validationTol * (current loss on the validation set);
+ * if the current loss on the validation set is less than or equal to 0.01, the
+ * difference in validation error is compared to an absolute tolerance of
+ * validationTol * 0.01.
* @group param
+ * @see validationIndicatorCol
*/
- // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
- // validationTol -> 1e-5
+ @Since("2.4.0")
+ final val validationTol: DoubleParam = new DoubleParam(this, "validationTol",
+ "Threshold for stopping early when fit with validation is used." +
+ "If the error rate on the validation input changes by less than the validationTol," +
+ "then learning will stop early (before `maxIter`)." +
+ "This parameter is ignored when fit without validation is used.",
+ ParamValidators.gtEq(0.0)
+ )
+
+ /** @group getParam */
+ @Since("2.4.0")
+ final def getValidationTol: Double = $(validationTol)
/**
* @deprecated This method is deprecated and will be removed in 3.0.0.
@@ -497,7 +514,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
@deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
def setStepSize(value: Double): this.type = set(stepSize, value)
- setDefault(maxIter -> 20, stepSize -> 0.1)
+ setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01)
setDefault(featureSubsetStrategy -> "all")
@@ -507,7 +524,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
// NOTE: The old API does not support "seed" so we ignore it.
- new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
+ new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize, getValidationTol)
}
/** Get old Gradient Boosting Loss type */
@@ -579,7 +596,11 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams
/** (private[ml]) Convert new loss to old loss. */
override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
+ convertToOldLossType(getLossType)
+ }
+
+ private[ml] def convertToOldLossType(loss: String): OldLoss = {
+ loss match {
case "squared" => OldSquaredError
case "absolute" => OldAbsoluteError
case _ =>
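A sketch of fit-with-validation for GBTs enabled by the new params above. It assumes the concrete estimators gain a `setValidationIndicatorCol` setter alongside the `HasValidationIndicatorCol` mixin; the column names and data are assumptions:

```scala
import org.apache.spark.ml.regression.GBTRegressor

// "isValidation" is an assumed boolean column: true marks rows held out for validation.
val gbt = new GBTRegressor()
  .setMaxIter(100)
  .setStepSize(0.1)
  .setValidationIndicatorCol("isValidation")

// Training stops early once the change in validation error falls below
// validationTol (default 0.01, as set in this patch).
// val model = gbt.fit(dataWithIndicator)  // dataWithIndicator is assumed
```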
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index a0b507d2e718c..f327f37bad204 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -144,7 +144,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
- logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+ instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.")
// Fit models in a Future for training in parallel
val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
@@ -155,7 +155,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
}
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
- logDebug(s"Got metric $metric for model trained with $paramMap.")
+ instr.logDebug(s"Got metric $metric for model trained with $paramMap.")
metric
} (executionContext)
}
@@ -169,12 +169,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
foldMetrics
}.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits
- logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
+ instr.logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) =
if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
else metrics.zipWithIndex.minBy(_._1)
- logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
- logInfo(s"Best cross-validation metric: $bestMetric.")
+ instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+ instr.logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new CrossValidatorModel(uid, bestModel, metrics)
@@ -234,8 +234,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
- DefaultParamsReader.getAndSetParams(cv, metadata,
- skipParams = Option(List("estimatorParamMaps")))
+ metadata.getAndSetParams(cv, skipParams = Option(List("estimatorParamMaps")))
cv
}
}
@@ -270,6 +269,17 @@ class CrossValidatorModel private[ml] (
this
}
+ // A Python-friendly auxiliary method
+ private[tuning] def setSubModels(subModels: JList[JList[Model[_]]])
+ : CrossValidatorModel = {
+ _subModels = if (subModels != null) {
+ Some(subModels.asScala.toArray.map(_.asScala.toArray))
+ } else {
+ None
+ }
+ this
+ }
+
/**
* @return submodels represented in two dimension array. The index of outer array is the
* fold index, and the index of inner array corresponds to the ordering of
@@ -413,8 +423,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
model.set(model.estimator, estimator)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
- DefaultParamsReader.getAndSetParams(model, metadata,
- skipParams = Option(List("estimatorParamMaps")))
+ metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps")))
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 88ff0dfd75e96..14d6a69c36747 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -143,7 +143,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
} else None
// Fit models in a Future for training in parallel
- logDebug(s"Train split with multiple sets of parameters.")
+ instr.logDebug(s"Train split with multiple sets of parameters.")
val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Double] {
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
@@ -153,7 +153,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
}
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
- logDebug(s"Got metric $metric for model trained with $paramMap.")
+ instr.logDebug(s"Got metric $metric for model trained with $paramMap.")
metric
} (executionContext)
}
@@ -165,12 +165,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
trainingDataset.unpersist()
validationDataset.unpersist()
- logInfo(s"Train validation split metrics: ${metrics.toSeq}")
+ instr.logInfo(s"Train validation split metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) =
if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
else metrics.zipWithIndex.minBy(_._1)
- logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
- logInfo(s"Best train validation split metric: $bestMetric.")
+ instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+ instr.logInfo(s"Best train validation split metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new TrainValidationSplitModel(uid, bestModel, metrics)
@@ -228,8 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
- DefaultParamsReader.getAndSetParams(tvs, metadata,
- skipParams = Option(List("estimatorParamMaps")))
+ metadata.getAndSetParams(tvs, skipParams = Option(List("estimatorParamMaps")))
tvs
}
}
@@ -262,6 +261,17 @@ class TrainValidationSplitModel private[ml] (
this
}
+ // A Python-friendly auxiliary method
+ private[tuning] def setSubModels(subModels: JList[Model[_]])
+ : TrainValidationSplitModel = {
+ _subModels = if (subModels != null) {
+ Some(subModels.asScala.toArray)
+ } else {
+ None
+ }
+ this
+ }
+
/**
* @return submodels represented in array. The index of array corresponds to the ordering of
* estimatorParamMaps
@@ -396,8 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
model.set(model.estimator, estimator)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
- DefaultParamsReader.getAndSetParams(model, metadata,
- skipParams = Option(List("estimatorParamMaps")))
+ metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps")))
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
new file mode 100644
index 0000000000000..6af4b3ebc2cc2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Column, Dataset, Row}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}
+
+
+private[spark] object DatasetUtils {
+
+ /**
+ * Cast a column in a Dataset to Vector type.
+ *
+ * The supported data types of the input column are
+ * - Vector
+ * - float/double type Array.
+ *
+ * Note: The returned column does not have Metadata.
+ *
+ * @param dataset input DataFrame
+ * @param colName column name.
+ * @return Vector column
+ */
+ def columnToVector(dataset: Dataset[_], colName: String): Column = {
+ val columnDataType = dataset.schema(colName).dataType
+ columnDataType match {
+ case _: VectorUDT => col(colName)
+ case fdt: ArrayType =>
+ val transferUDF = fdt.elementType match {
+ case _: FloatType => udf(f = (vector: Seq[Float]) => {
+ val inputArray = Array.fill[Double](vector.size)(0.0)
+ vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble)
+ Vectors.dense(inputArray)
+ })
+ case _: DoubleType => udf((vector: Seq[Double]) => {
+ Vectors.dense(vector.toArray)
+ })
+ case other =>
+ throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector")
+ }
+ transferUDF(col(colName))
+ case other =>
+ throw new IllegalArgumentException(s"$other column cannot be cast to Vector")
+ }
+ }
+
+ def columnToOldVector(dataset: Dataset[_], colName: String): RDD[OldVector] = {
+ dataset.select(columnToVector(dataset, colName))
+ .rdd.map {
+ case Row(point: Vector) => OldVectors.fromML(point)
+ }
+ }
+}
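DatasetUtils is `private[spark]`, so user code cannot call it directly; the following sketch only illustrates the conversion it performs for a double-array column. The SparkSession `spark` and the column names are assumptions:

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions.{col, udf}
import spark.implicits._  // assumes an in-scope SparkSession named spark

val df = Seq((0, Seq(1.0, 2.0, 3.0)), (1, Seq(4.0, 5.0, 6.0))).toDF("id", "features")

// Equivalent of columnToVector's DoubleType branch: wrap the array values in a dense Vector.
val toVector = udf((xs: Seq[Double]) => Vectors.dense(xs.toArray))
val withVectors = df.withColumn("featuresVec", toVector(col("features")))
```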
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 7c46f45c59717..11f46eb9e4359 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -17,7 +17,9 @@
package org.apache.spark.ml.util
-import java.util.concurrent.atomic.AtomicLong
+import java.util.UUID
+
+import scala.reflect.ClassTag
import org.json4s._
import org.json4s.JsonDSL._
@@ -28,6 +30,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.Param
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
+import org.apache.spark.util.Utils
/**
* A small wrapper that defines a training session for an estimator, and some methods to log
@@ -40,11 +43,14 @@ import org.apache.spark.sql.Dataset
* @tparam E the type of the estimator
*/
private[spark] class Instrumentation[E <: Estimator[_]] private (
- estimator: E, dataset: RDD[_]) extends Logging {
+ val estimator: E,
+ val dataset: RDD[_]) extends Logging {
- private val id = Instrumentation.counter.incrementAndGet()
+ private val id = UUID.randomUUID()
private val prefix = {
- val className = estimator.getClass.getSimpleName
+ // estimator.getClass.getSimpleName can cause Malformed class name error,
+ // call safer `Utils.getSimpleName` instead
+ val className = Utils.getSimpleName(estimator.getClass)
s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
}
@@ -56,12 +62,38 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
}
/**
- * Logs a message with a prefix that uniquely identifies the training session.
+ * Logs a debug message with a prefix that uniquely identifies the training session.
+ */
+ override def logDebug(msg: => String): Unit = {
+ super.logDebug(prefix + msg)
+ }
+
+ /**
+ * Logs a warning message with a prefix that uniquely identifies the training session.
+ */
+ override def logWarning(msg: => String): Unit = {
+ super.logWarning(prefix + msg)
+ }
+
+ /**
+ * Logs an error message with a prefix that uniquely identifies the training session.
*/
- def log(msg: String): Unit = {
- logInfo(prefix + msg)
+ override def logError(msg: => String): Unit = {
+ super.logError(prefix + msg)
}
+ /**
+ * Logs an info message with a prefix that uniquely identifies the training session.
+ */
+ override def logInfo(msg: => String): Unit = {
+ super.logInfo(prefix + msg)
+ }
+
+ /**
+ * Alias for logInfo, see above.
+ */
+ def log(msg: String): Unit = logInfo(msg)
+
/**
* Logs the value of the given parameters for the estimator being used in this session.
*/
@@ -77,11 +109,15 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
}
def logNumFeatures(num: Long): Unit = {
- log(compact(render("numFeatures" -> num)))
+ logNamedValue(Instrumentation.loggerTags.numFeatures, num)
}
def logNumClasses(num: Long): Unit = {
- log(compact(render("numClasses" -> num)))
+ logNamedValue(Instrumentation.loggerTags.numClasses, num)
+ }
+
+ def logNumExamples(num: Long): Unit = {
+ logNamedValue(Instrumentation.loggerTags.numExamples, num)
}
/**
@@ -95,6 +131,23 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
log(compact(render(name -> value)))
}
+ def logNamedValue(name: String, value: Double): Unit = {
+ log(compact(render(name -> value)))
+ }
+
+ def logNamedValue(name: String, value: Array[String]): Unit = {
+ log(compact(render(name -> compact(render(value.toSeq)))))
+ }
+
+ def logNamedValue(name: String, value: Array[Long]): Unit = {
+ log(compact(render(name -> compact(render(value.toSeq)))))
+ }
+
+ def logNamedValue(name: String, value: Array[Double]): Unit = {
+ log(compact(render(name -> compact(render(value.toSeq)))))
+ }
+
+
/**
* Logs the successful completion of the training session.
*/
@@ -107,7 +160,14 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
* Some common methods for logging information about a training session.
*/
private[spark] object Instrumentation {
- private val counter = new AtomicLong(0)
+
+ object loggerTags {
+ val numFeatures = "numFeatures"
+ val numClasses = "numClasses"
+ val numExamples = "numExamples"
+ val meanOfLabels = "meanOfLabels"
+ val varianceOfLabels = "varianceOfLabels"
+ }
/**
* Creates an instrumentation object for a training session.
@@ -126,3 +186,56 @@ private[spark] object Instrumentation {
}
}
+
+/**
+ * A small wrapper that contains an optional `Instrumentation` object.
+ * Provide some log methods, if the containing `Instrumentation` object is defined,
+ * will log via it, otherwise will log via common logger.
+ */
+private[spark] class OptionalInstrumentation private(
+ val instrumentation: Option[Instrumentation[_ <: Estimator[_]]],
+ val className: String) extends Logging {
+
+ protected override def logName: String = className
+
+ override def logInfo(msg: => String) {
+ instrumentation match {
+ case Some(instr) => instr.logInfo(msg)
+ case None => super.logInfo(msg)
+ }
+ }
+
+ override def logWarning(msg: => String) {
+ instrumentation match {
+ case Some(instr) => instr.logWarning(msg)
+ case None => super.logWarning(msg)
+ }
+ }
+
+ override def logError(msg: => String) {
+ instrumentation match {
+ case Some(instr) => instr.logError(msg)
+ case None => super.logError(msg)
+ }
+ }
+}
+
+private[spark] object OptionalInstrumentation {
+
+ /**
+ * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object.
+ */
+ def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = {
+ new OptionalInstrumentation(Some(instr),
+ instr.estimator.getClass.getName.stripSuffix("$"))
+ }
+
+ /**
+ * Creates an `OptionalInstrumentation` object from a `Class` object.
+ * The created `OptionalInstrumentation` object will log messages via common logger and use the
+ * specified class name as logger name.
+ */
+ def create(clazz: Class[_]): OptionalInstrumentation = {
+ new OptionalInstrumentation(None, clazz.getName.stripSuffix("$"))
+ }
+}
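OptionalInstrumentation is `private[spark]`; the sketch below only illustrates the intended pattern for internal helpers shared between estimators, where the same code logs through the training session when one exists and through an ordinary logger otherwise. The helper object and its input are assumptions:

```scala
import org.apache.spark.ml.util.OptionalInstrumentation
import org.apache.spark.rdd.RDD

object LabelSummarizer {
  def meanOfLabels(
      labels: RDD[Double],
      instr: OptionalInstrumentation = OptionalInstrumentation.create(getClass)): Double = {
    val mean = labels.mean()
    // Goes to the training-session log when an Instrumentation-backed instance is passed
    // in (via OptionalInstrumentation.create(instr)), and to this object's logger otherwise.
    instr.logInfo(s"meanOfLabels: $mean")
    mean
  }
}
```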
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index a616907800969..72a60e04360d6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -18,9 +18,11 @@
package org.apache.spark.ml.util
import java.io.IOException
-import java.util.Locale
+import java.util.{Locale, ServiceLoader}
+import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.util.{Failure, Success, Try}
import org.apache.hadoop.fs.Path
import org.json4s._
@@ -28,8 +30,8 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.SparkContext
-import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
@@ -37,7 +39,7 @@ import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{Utils, VersionUtils}
/**
* Trait for `MLWriter` and `MLReader`.
@@ -86,7 +88,82 @@ private[util] sealed trait BaseReadWrite {
}
/**
- * Abstract class for utility classes that can save ML instances.
+ * Trait to be implemented by objects that provide ML exportability.
+ *
+ * A new instance of this class will be instantiated each time a save call is made.
+ *
+ * Implementations must have a valid zero-argument constructor, which is used to instantiate them.
+ *
+ * @since 2.4.0
+ */
+@InterfaceStability.Unstable
+@Since("2.4.0")
+trait MLWriterFormat {
+ /**
+ * Function to write the provided pipeline stage out.
+ *
+ * @param path The path to write the result out to.
+ * @param session SparkSession associated with the write request.
+ * @param optionMap User provided options stored as strings.
+ * @param stage The pipeline stage to be saved.
+ */
+ @Since("2.4.0")
+ def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String],
+ stage: PipelineStage): Unit
+}
+
+/**
+ * ML export formats should implement this trait so that users can specify a short name rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a save call is made.
+ *
+ * @since 2.4.0
+ */
+@InterfaceStability.Unstable
+@Since("2.4.0")
+trait MLFormatRegister extends MLWriterFormat {
+ /**
+ * The string that represents the format that this format provider uses. This, along with
+ * stageName, is overridden by children to provide a nice alias for the writer. For example:
+ *
+ * {{{
+ * override def format(): String =
+ * "pmml"
+ * }}}
+ * Indicates that this format is capable of saving a PMML model.
+ *
+ * Implementations must have a valid zero-argument constructor, which is used to instantiate them.
+ *
+ * Format discovery is done using a ServiceLoader so make sure to list your format in
+ * META-INF/services.
+ * @since 2.4.0
+ */
+ @Since("2.4.0")
+ def format(): String
+
+ /**
+ * The string that represents the stage type that this writer supports. This, along with
+ * format, is overridden by children to provide a nice alias for the writer. For example:
+ *
+ * {{{
+ * override def stageName(): String =
+ * "org.apache.spark.ml.regression.LinearRegressionModel"
+ * }}}
+ * Indicates that this format is capable of saving Spark's own PMML model.
+ *
+ * Format discovery is done using a ServiceLoader so make sure to list your format in
+ * META-INF/services.
+ * @since 2.4.0
+ */
+ @Since("2.4.0")
+ def stageName(): String
+
+ private[ml] def shortName(): String = s"${format()}+${stageName()}"
+}
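A hypothetical format provider built on the traits above. The class name, the "csv-summary" format string, and the choice of LinearRegressionModel as the target stage are assumptions; a real provider must also be listed under META-INF/services/org.apache.spark.ml.util.MLFormatRegister so the ServiceLoader can find it:

```scala
import scala.collection.mutable

import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.regression.LinearRegressionModel
import org.apache.spark.ml.util.MLFormatRegister
import org.apache.spark.sql.SparkSession

class CsvSummaryWriter extends MLFormatRegister {
  override def format(): String = "csv-summary"
  override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"

  override def write(path: String, session: SparkSession,
      optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
    val model = stage.asInstanceOf[LinearRegressionModel]
    // Dump a one-line summary of the fitted model; the output layout is illustrative.
    val summary = s"${model.intercept},${model.coefficients.toArray.mkString(",")}"
    session.sparkContext.parallelize(Seq(summary), numSlices = 1).saveAsTextFile(path)
  }
}
```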
+
+/**
+ * Abstract class for utility classes that can save ML instances in Spark's internal format.
*/
@Since("1.6.0")
abstract class MLWriter extends BaseReadWrite with Logging {
@@ -110,6 +187,15 @@ abstract class MLWriter extends BaseReadWrite with Logging {
@Since("1.6.0")
protected def saveImpl(path: String): Unit
+ /**
+ * Overwrites if the output path already exists.
+ */
+ @Since("1.6.0")
+ def overwrite(): this.type = {
+ shouldOverwrite = true
+ this
+ }
+
/**
* Map to store extra options for this writer.
*/
@@ -126,15 +212,73 @@ abstract class MLWriter extends BaseReadWrite with Logging {
this
}
+ // override for Java compatibility
+ @Since("1.6.0")
+ override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
+
+ // override for Java compatibility
+ @Since("1.6.0")
+ override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+@InterfaceStability.Unstable
+@Since("2.4.0")
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
+ private var source: String = "internal"
+
/**
- * Overwrites if the output path already exists.
+ * Specifies the format of ML export (e.g. "pmml", "internal", or
+ * the fully qualified class name for export).
*/
- @Since("1.6.0")
- def overwrite(): this.type = {
- shouldOverwrite = true
+ @Since("2.4.0")
+ def format(source: String): this.type = {
+ this.source = source
this
}
+ /**
+ * Dispatches the save to the correct MLFormat.
+ */
+ @Since("2.4.0")
+ @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+ @throws[SparkException]("If multiple sources for a given short name format are found.")
+ override protected def saveImpl(path: String): Unit = {
+ val loader = Utils.getContextOrSparkClassLoader
+ val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader)
+ val stageName = stage.getClass.getName
+ val targetName = s"$source+$stageName"
+ val formats = serviceLoader.asScala.toList
+ val shortNames = formats.map(_.shortName())
+ val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
+ // requested name did not match any given registered alias
+ case Nil =>
+ Try(loader.loadClass(source)) match {
+ case Success(writer) =>
+ // Found the ML writer using the fully qualified path
+ writer
+ case Failure(error) =>
+ throw new SparkException(
+ s"Could not load requested format $source for $stageName ($targetName) had $formats" +
+ s"supporting $shortNames", error)
+ }
+ case head :: Nil =>
+ head.getClass
+ case _ =>
+ // Multiple sources
+ throw new SparkException(
+ s"Multiple writers found for $source+$stageName, try using the class name of the writer")
+ }
+ if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
+ val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat]
+ writer.write(path, sparkSession, optionMap, stage)
+ } else {
+ throw new SparkException(s"ML source $source is not a valid MLWriterFormat")
+ }
+ }
+
// override for Java compatibility
override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
@@ -162,6 +306,19 @@ trait MLWritable {
def save(path: String): Unit = write.save(path)
}
+/**
+ * Trait for classes that provide `GeneralMLWriter`.
+ */
+@Since("2.4.0")
+@InterfaceStability.Unstable
+trait GeneralMLWritable extends MLWritable {
+ /**
+ * Returns an `MLWriter` instance for this ML instance.
+ */
+ @Since("2.4.0")
+ override def write: GeneralMLWriter
+}
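How a user would then pick a format when saving, assuming `model` is a stage whose `write` returns a `GeneralMLWriter` (i.e. it mixes in `GeneralMLWritable`) and that a matching provider such as a PMML writer is on the classpath; the paths are illustrative:

```scala
// Export through a registered short name or a fully qualified writer class name.
model.write.format("pmml").save("/tmp/model-pmml")

// Without format(), the default "internal" source is used; dispatch still goes
// through the registry above, assuming an internal writer is registered for the stage.
model.write.save("/tmp/model-internal")
```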
+
/**
* :: DeveloperApi ::
*
@@ -264,6 +421,7 @@ private[ml] object DefaultParamsWriter {
* - timestamp
* - sparkVersion
* - uid
+ * - defaultParamMap
* - paramMap
* - (optionally, extra metadata)
*
@@ -296,15 +454,20 @@ private[ml] object DefaultParamsWriter {
paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
- val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+ val params = instance.paramMap.toSeq
+ val defaultParams = instance.defaultParamMap.toSeq
val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList))
+ val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) =>
+ p.name -> parse(p.jsonEncode(v))
+ }.toList)
val basicMetadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
- ("paramMap" -> jsonParams)
+ ("paramMap" -> jsonParams) ~
+ ("defaultParamMap" -> jsonDefaultParams)
val metadata = extraMetadata match {
case Some(jObject) =>
basicMetadata ~ jObject
@@ -331,7 +494,7 @@ private[ml] class DefaultParamsReader[T] extends MLReader[T] {
val cls = Utils.classForName(metadata.className)
val instance =
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
- DefaultParamsReader.getAndSetParams(instance, metadata)
+ metadata.getAndSetParams(instance)
instance.asInstanceOf[T]
}
}
@@ -342,6 +505,8 @@ private[ml] object DefaultParamsReader {
* All info from metadata file.
*
* @param params paramMap, as a `JValue`
+ * @param defaultParams defaultParamMap, as a `JValue`. For metadata files prior to Spark 2.4,
+ * this is `JNothing`.
* @param metadata All metadata, including the other fields
* @param metadataJson Full metadata file String (for debugging)
*/
@@ -351,27 +516,90 @@ private[ml] object DefaultParamsReader {
timestamp: Long,
sparkVersion: String,
params: JValue,
+ defaultParams: JValue,
metadata: JValue,
metadataJson: String) {
+
+ private def getValueFromParams(params: JValue): Seq[(String, JValue)] = {
+ params match {
+ case JObject(pairs) => pairs
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Cannot recognize JSON metadata: $metadataJson.")
+ }
+ }
+
/**
* Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
* This can be useful for getting a Param value before an instance of `Params`
- * is available.
+ * is available. This looks up `params` first and, if the param is not found there, falls back to
+ * `defaultParams`.
*/
def getParamValue(paramName: String): JValue = {
implicit val format = DefaultFormats
- params match {
+
+ // Looking up for `params` first.
+ var pairs = getValueFromParams(params)
+ var foundPairs = pairs.filter { case (pName, jsonValue) =>
+ pName == paramName
+ }
+ if (foundPairs.length == 0) {
+ // Looking up for `defaultParams` then.
+ pairs = getValueFromParams(defaultParams)
+ foundPairs = pairs.filter { case (pName, jsonValue) =>
+ pName == paramName
+ }
+ }
+ assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" +
+ s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
+
+ foundPairs.map(_._2).head
+ }
+
+ /**
+ * Extract Params from metadata, and set them in the instance.
+ * This works if all Params (except params included by `skipParams` list) implement
+ * [[org.apache.spark.ml.param.Param.jsonDecode()]].
+ *
+ * @param skipParams The params included in `skipParams` won't be set. This is useful if some
+ * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]]
+ * and need special handling.
+ */
+ def getAndSetParams(
+ instance: Params,
+ skipParams: Option[List[String]] = None): Unit = {
+ setParams(instance, skipParams, isDefault = false)
+
+ // For metadata files prior to Spark 2.4, there is no default section.
+ val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion)
+ if (major > 2 || (major == 2 && minor >= 4)) {
+ setParams(instance, skipParams, isDefault = true)
+ }
+ }
+
+ private def setParams(
+ instance: Params,
+ skipParams: Option[List[String]],
+ isDefault: Boolean): Unit = {
+ implicit val format = DefaultFormats
+ val paramsToSet = if (isDefault) defaultParams else params
+ paramsToSet match {
case JObject(pairs) =>
- val values = pairs.filter { case (pName, jsonValue) =>
- pName == paramName
- }.map(_._2)
- assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
- s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
- values.head
+ pairs.foreach { case (paramName, jsonValue) =>
+ if (skipParams == None || !skipParams.get.contains(paramName)) {
+ val param = instance.getParam(paramName)
+ val value = param.jsonDecode(compact(render(jsonValue)))
+ if (isDefault) {
+ Params.setDefault(instance, param, value)
+ } else {
+ instance.set(param, value)
+ }
+ }
+ }
case _ =>
throw new IllegalArgumentException(
- s"Cannot recognize JSON metadata: $metadataJson.")
+ s"Cannot recognize JSON metadata: ${metadataJson}.")
}
}
}
@@ -404,43 +632,14 @@ private[ml] object DefaultParamsReader {
val uid = (metadata \ "uid").extract[String]
val timestamp = (metadata \ "timestamp").extract[Long]
val sparkVersion = (metadata \ "sparkVersion").extract[String]
+ val defaultParams = metadata \ "defaultParamMap"
val params = metadata \ "paramMap"
if (expectedClassName.nonEmpty) {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}
- Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
- }
-
- /**
- * Extract Params from metadata, and set them in the instance.
- * This works if all Params (except params included by `skipParams` list) implement
- * [[org.apache.spark.ml.param.Param.jsonDecode()]].
- *
- * @param skipParams The params included in `skipParams` won't be set. This is useful if some
- * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]]
- * and need special handling.
- * TODO: Move to [[Metadata]] method
- */
- def getAndSetParams(
- instance: Params,
- metadata: Metadata,
- skipParams: Option[List[String]] = None): Unit = {
- implicit val format = DefaultFormats
- metadata.params match {
- case JObject(pairs) =>
- pairs.foreach { case (paramName, jsonValue) =>
- if (skipParams == None || !skipParams.get.contains(paramName)) {
- val param = instance.getParam(paramName)
- val value = param.jsonDecode(compact(render(jsonValue)))
- instance.set(param, value)
- }
- }
- case _ =>
- throw new IllegalArgumentException(
- s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
- }
+ Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr)
}
/**
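With getAndSetParams now living on Metadata, a loader can restore both explicitly set and default Param values in one call. A minimal sketch of that flow, assuming a hypothetical Params implementation `MyModel`, an existing SparkContext `sc`, and a `modelPath` already written by DefaultParamsWriter:

    import org.apache.spark.ml.util.DefaultParamsReader

    // MyModel and modelPath are illustrative, not part of this patch.
    val metadata = DefaultParamsReader.loadMetadata(modelPath, sc, classOf[MyModel].getName)
    val model = new MyModel(metadata.uid)
    metadata.getAndSetParams(model)                  // sets paramMap, then defaultParamMap for 2.4+ metadata
    val maxIter = metadata.getParamValue("maxIter")  // falls back to defaultParams when not explicitly set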
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 334410c9620de..d9a3f85ef9a24 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -17,7 +17,8 @@
package org.apache.spark.ml.util
-import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType}
+import org.apache.spark.ml.linalg.VectorUDT
+import org.apache.spark.sql.types._
/**
@@ -101,4 +102,17 @@ private[spark] object SchemaUtils {
require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.")
StructType(schema.fields :+ col)
}
+
+ /**
+ * Check whether the given column in the schema is one of the supported vector-compatible
+ * types: Vector, Array[Double], Array[Float].
+ * @param schema input schema
+ * @param colName column name
+ */
+ def validateVectorCompatibleColumn(schema: StructType, colName: String): Unit = {
+ val typeCandidates = List(new VectorUDT,
+ new ArrayType(DoubleType, false),
+ new ArrayType(FloatType, false))
+ checkColumnTypes(schema, colName, typeCandidates)
+ }
}
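A quick sketch of what the new check accepts and rejects. This would live inside Spark's own code or tests (SchemaUtils is private[spark]), and the field names are illustrative:

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("arr", ArrayType(FloatType, containsNull = false)),
      StructField("txt", StringType)))

    SchemaUtils.validateVectorCompatibleColumn(schema, "arr")    // passes: Array[Float]
    // SchemaUtils.validateVectorCompatibleColumn(schema, "txt") // would throw IllegalArgumentException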
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index b32d3f252ae59..db3f074ecfbac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -572,10 +572,7 @@ private[python] class PythonMLLibAPI extends Serializable {
data: JavaRDD[java.lang.Iterable[Any]],
minSupport: Double,
numPartitions: Int): FPGrowthModel[Any] = {
- val fpg = new FPGrowth()
- .setMinSupport(minSupport)
- .setNumPartitions(numPartitions)
-
+ val fpg = new FPGrowth(minSupport, numPartitions)
val model = fpg.run(data.rdd.map(_.asScala.toArray))
new FPGrowthModelWrapper(model)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 2221f4c0edc17..98af487306dcc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
-import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -57,7 +57,8 @@ class BisectingKMeans private (
private var k: Int,
private var maxIterations: Int,
private var minDivisibleClusterSize: Double,
- private var seed: Long) extends Logging {
+ private var seed: Long,
+ private var distanceMeasure: String) extends Logging {
import BisectingKMeans._
@@ -65,7 +66,7 @@ class BisectingKMeans private (
* Constructs with the default configuration
*/
@Since("1.6.0")
- def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##)
+ def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##, DistanceMeasure.EUCLIDEAN)
/**
* Sets the desired number of leaf clusters (default: 4).
@@ -134,6 +135,22 @@ class BisectingKMeans private (
@Since("1.6.0")
def getSeed: Long = this.seed
+ /**
+ * The distance measure used by the algorithm.
+ */
+ @Since("2.4.0")
+ def getDistanceMeasure: String = distanceMeasure
+
+ /**
+ * Set the distance measure used by the algorithm.
+ */
+ @Since("2.4.0")
+ def setDistanceMeasure(distanceMeasure: String): this.type = {
+ DistanceMeasure.validateDistanceMeasure(distanceMeasure)
+ this.distanceMeasure = distanceMeasure
+ this
+ }
+
/**
* Runs the bisecting k-means algorithm.
* @param input RDD of vectors
@@ -147,11 +164,13 @@ class BisectingKMeans private (
}
val d = input.map(_.size).first()
logInfo(s"Feature dimension: $d.")
+
+ val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure)
// Compute and cache vector norms for fast distance computation.
val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK)
val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) }
var assignments = vectors.map(v => (ROOT_INDEX, v))
- var activeClusters = summarize(d, assignments)
+ var activeClusters = summarize(d, assignments, dMeasure)
val rootSummary = activeClusters(ROOT_INDEX)
val n = rootSummary.size
logInfo(s"Number of points: $n.")
@@ -184,24 +203,25 @@ class BisectingKMeans private (
val divisibleIndices = divisibleClusters.keys.toSet
logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.")
var newClusterCenters = divisibleClusters.flatMap { case (index, summary) =>
- val (left, right) = splitCenter(summary.center, random)
+ val (left, right) = splitCenter(summary.center, random, dMeasure)
Iterator((leftChildIndex(index), left), (rightChildIndex(index), right))
}.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map
var newClusters: Map[Long, ClusterSummary] = null
var newAssignments: RDD[(Long, VectorWithNorm)] = null
for (iter <- 0 until maxIterations) {
- newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters)
+ newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters,
+ dMeasure)
.filter { case (index, _) =>
divisibleIndices.contains(parentIndex(index))
}
- newClusters = summarize(d, newAssignments)
+ newClusters = summarize(d, newAssignments, dMeasure)
newClusterCenters = newClusters.mapValues(_.center).map(identity)
}
if (preIndices != null) {
preIndices.unpersist(false)
}
preIndices = indices
- indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys
+ indices = updateAssignments(assignments, divisibleIndices, newClusterCenters, dMeasure).keys
.persist(StorageLevel.MEMORY_AND_DISK)
assignments = indices.zip(vectors)
inactiveClusters ++= activeClusters
@@ -222,8 +242,8 @@ class BisectingKMeans private (
}
norms.unpersist(false)
val clusters = activeClusters ++ inactiveClusters
- val root = buildTree(clusters)
- new BisectingKMeansModel(root)
+ val root = buildTree(clusters, dMeasure)
+ new BisectingKMeansModel(root, this.distanceMeasure)
}
/**
@@ -266,8 +286,9 @@ private object BisectingKMeans extends Serializable {
*/
private def summarize(
d: Int,
- assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = {
- assignments.aggregateByKey(new ClusterSummaryAggregator(d))(
+ assignments: RDD[(Long, VectorWithNorm)],
+ distanceMeasure: DistanceMeasure): Map[Long, ClusterSummary] = {
+ assignments.aggregateByKey(new ClusterSummaryAggregator(d, distanceMeasure))(
seqOp = (agg, v) => agg.add(v),
combOp = (agg1, agg2) => agg1.merge(agg2)
).mapValues(_.summary)
@@ -278,7 +299,8 @@ private object BisectingKMeans extends Serializable {
* Cluster summary aggregator.
* @param d feature dimension
*/
- private class ClusterSummaryAggregator(val d: Int) extends Serializable {
+ private class ClusterSummaryAggregator(val d: Int, val distanceMeasure: DistanceMeasure)
+ extends Serializable {
private var n: Long = 0L
private val sum: Vector = Vectors.zeros(d)
private var sumSq: Double = 0.0
@@ -288,7 +310,7 @@ private object BisectingKMeans extends Serializable {
n += 1L
// TODO: use a numerically stable approach to estimate cost
sumSq += v.norm * v.norm
- BLAS.axpy(1.0, v.vector, sum)
+ distanceMeasure.updateClusterSum(v, sum)
this
}
@@ -296,19 +318,15 @@ private object BisectingKMeans extends Serializable {
def merge(other: ClusterSummaryAggregator): this.type = {
n += other.n
sumSq += other.sumSq
- BLAS.axpy(1.0, other.sum, sum)
+ distanceMeasure.updateClusterSum(new VectorWithNorm(other.sum), sum)
this
}
/** Returns the summary. */
def summary: ClusterSummary = {
- val mean = sum.copy
- if (n > 0L) {
- BLAS.scal(1.0 / n, mean)
- }
- val center = new VectorWithNorm(mean)
- val cost = math.max(sumSq - n * center.norm * center.norm, 0.0)
- new ClusterSummary(n, center, cost)
+ val center = distanceMeasure.centroid(sum.copy, n)
+ val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), n, sumSq)
+ ClusterSummary(n, center, cost)
}
}
@@ -321,16 +339,13 @@ private object BisectingKMeans extends Serializable {
*/
private def splitCenter(
center: VectorWithNorm,
- random: Random): (VectorWithNorm, VectorWithNorm) = {
+ random: Random,
+ distanceMeasure: DistanceMeasure): (VectorWithNorm, VectorWithNorm) = {
val d = center.vector.size
val norm = center.norm
val level = 1e-4 * norm
val noise = Vectors.dense(Array.fill(d)(random.nextDouble()))
- val left = center.vector.copy
- BLAS.axpy(-level, noise, left)
- val right = center.vector.copy
- BLAS.axpy(level, noise, right)
- (new VectorWithNorm(left), new VectorWithNorm(right))
+ distanceMeasure.symmetricCentroids(level, noise, center.vector)
}
/**
@@ -343,16 +358,20 @@ private object BisectingKMeans extends Serializable {
private def updateAssignments(
assignments: RDD[(Long, VectorWithNorm)],
divisibleIndices: Set[Long],
- newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = {
+ newClusterCenters: Map[Long, VectorWithNorm],
+ distanceMeasure: DistanceMeasure): RDD[(Long, VectorWithNorm)] = {
assignments.map { case (index, v) =>
if (divisibleIndices.contains(index)) {
val children = Seq(leftChildIndex(index), rightChildIndex(index))
- val newClusterChildren = children.filter(newClusterCenters.contains(_))
+ val newClusterChildren = children.filter(newClusterCenters.contains)
+ val newClusterChildrenCenterToId =
+ newClusterChildren.map(id => newClusterCenters(id) -> id).toMap
+ val newClusterChildrenCenters = newClusterChildrenCenterToId.keys.toArray
if (newClusterChildren.nonEmpty) {
- val selected = newClusterChildren.minBy { child =>
- EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v)
- }
- (selected, v)
+ val selected = distanceMeasure.findClosest(newClusterChildrenCenters, v)._1
+ val center = newClusterChildrenCenters(selected)
+ val id = newClusterChildrenCenterToId(center)
+ (id, v)
} else {
(index, v)
}
@@ -367,7 +386,9 @@ private object BisectingKMeans extends Serializable {
* @param clusters a map from cluster indices to corresponding cluster summaries
* @return the root node of the clustering tree
*/
- private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = {
+ private def buildTree(
+ clusters: Map[Long, ClusterSummary],
+ distanceMeasure: DistanceMeasure): ClusteringTreeNode = {
var leafIndex = 0
var internalIndex = -1
@@ -385,11 +406,11 @@ private object BisectingKMeans extends Serializable {
internalIndex -= 1
val leftIndex = leftChildIndex(rawIndex)
val rightIndex = rightChildIndex(rawIndex)
- val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
- val height = math.sqrt(indexes.map { childIndex =>
- EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center)
- }.max)
- val children = indexes.map(buildSubTree(_)).toArray
+ val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains)
+ val height = indexes.map { childIndex =>
+ distanceMeasure.distance(center, clusters(childIndex).center)
+ }.max
+ val children = indexes.map(buildSubTree).toArray
new ClusteringTreeNode(index, size, center, cost, height, children)
} else {
val index = leafIndex
@@ -441,42 +462,45 @@ private[clustering] class ClusteringTreeNode private[clustering] (
def center: Vector = centerWithNorm.vector
/** Predicts the leaf cluster node index that the input point belongs to. */
- def predict(point: Vector): Int = {
- val (index, _) = predict(new VectorWithNorm(point))
+ def predict(point: Vector, distanceMeasure: DistanceMeasure): Int = {
+ val (index, _) = predict(new VectorWithNorm(point), distanceMeasure)
index
}
/** Returns the full prediction path from root to leaf. */
- def predictPath(point: Vector): Array[ClusteringTreeNode] = {
- predictPath(new VectorWithNorm(point)).toArray
+ def predictPath(point: Vector, distanceMeasure: DistanceMeasure): Array[ClusteringTreeNode] = {
+ predictPath(new VectorWithNorm(point), distanceMeasure).toArray
}
/** Returns the full prediction path from root to leaf. */
- private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = {
+ private def predictPath(
+ pointWithNorm: VectorWithNorm,
+ distanceMeasure: DistanceMeasure): List[ClusteringTreeNode] = {
if (isLeaf) {
this :: Nil
} else {
val selected = children.minBy { child =>
- EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)
+ distanceMeasure.distance(child.centerWithNorm, pointWithNorm)
}
- selected :: selected.predictPath(pointWithNorm)
+ selected :: selected.predictPath(pointWithNorm, distanceMeasure)
}
}
/**
- * Computes the cost (squared distance to the predicted leaf cluster center) of the input point.
+ * Computes the cost of the input point.
*/
- def computeCost(point: Vector): Double = {
- val (_, cost) = predict(new VectorWithNorm(point))
+ def computeCost(point: Vector, distanceMeasure: DistanceMeasure): Double = {
+ val (_, cost) = predict(new VectorWithNorm(point), distanceMeasure)
cost
}
/**
* Predicts the cluster index and the cost of the input point.
*/
- private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = {
- predict(pointWithNorm,
- EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm))
+ private def predict(
+ pointWithNorm: VectorWithNorm,
+ distanceMeasure: DistanceMeasure): (Int, Double) = {
+ predict(pointWithNorm, distanceMeasure.cost(centerWithNorm, pointWithNorm), distanceMeasure)
}
/**
@@ -486,14 +510,17 @@ private[clustering] class ClusteringTreeNode private[clustering] (
* @return (predicted leaf cluster index, cost)
*/
@tailrec
- private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = {
+ private def predict(
+ pointWithNorm: VectorWithNorm,
+ cost: Double,
+ distanceMeasure: DistanceMeasure): (Int, Double) = {
if (isLeaf) {
(index, cost)
} else {
val (selectedChild, minCost) = children.map { child =>
- (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm))
+ (child, distanceMeasure.cost(child.centerWithNorm, pointWithNorm))
}.minBy(_._2)
- selectedChild.predict(pointWithNorm, minCost)
+ selectedChild.predict(pointWithNorm, minCost, distanceMeasure)
}
}
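With the new setter in place, mllib's bisecting k-means can be driven with cosine distance end to end. A minimal usage sketch, assuming an input `data: RDD[org.apache.spark.mllib.linalg.Vector]` already exists:

    import org.apache.spark.mllib.clustering.{BisectingKMeans, DistanceMeasure}
    import org.apache.spark.mllib.linalg.Vectors

    val bkm = new BisectingKMeans()
      .setK(4)
      .setDistanceMeasure(DistanceMeasure.COSINE)
    val model = bkm.run(data)
    model.predict(Vectors.dense(0.3, 0.7))   // leaf cluster chosen by cosine distance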
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
index 633bda6aac804..9d115afcea75d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -40,9 +40,16 @@ import org.apache.spark.sql.{Row, SparkSession}
*/
@Since("1.6.0")
class BisectingKMeansModel private[clustering] (
- private[clustering] val root: ClusteringTreeNode
+ private[clustering] val root: ClusteringTreeNode,
+ @Since("2.4.0") val distanceMeasure: String
) extends Serializable with Saveable with Logging {
+ @Since("1.6.0")
+ def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN)
+
+ private val distanceMeasureInstance: DistanceMeasure =
+ DistanceMeasure.decodeFromString(distanceMeasure)
+
/**
* Leaf cluster centers.
*/
@@ -59,7 +66,7 @@ class BisectingKMeansModel private[clustering] (
*/
@Since("1.6.0")
def predict(point: Vector): Int = {
- root.predict(point)
+ root.predict(point, distanceMeasureInstance)
}
/**
@@ -67,7 +74,7 @@ class BisectingKMeansModel private[clustering] (
*/
@Since("1.6.0")
def predict(points: RDD[Vector]): RDD[Int] = {
- points.map { p => root.predict(p) }
+ points.map { p => root.predict(p, distanceMeasureInstance) }
}
/**
@@ -82,7 +89,7 @@ class BisectingKMeansModel private[clustering] (
*/
@Since("1.6.0")
def computeCost(point: Vector): Double = {
- root.computeCost(point)
+ root.computeCost(point, distanceMeasureInstance)
}
/**
@@ -91,7 +98,7 @@ class BisectingKMeansModel private[clustering] (
*/
@Since("1.6.0")
def computeCost(data: RDD[Vector]): Double = {
- data.map(root.computeCost).sum()
+ data.map(root.computeCost(_, distanceMeasureInstance)).sum()
}
/**
@@ -113,18 +120,19 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
@Since("2.0.0")
override def load(sc: SparkContext, path: String): BisectingKMeansModel = {
- val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path)
- implicit val formats = DefaultFormats
- val rootId = (metadata \ "rootId").extract[Int]
- val classNameV1_0 = SaveLoadV1_0.thisClassName
+ val (loadedClassName, formatVersion, __) = Loader.loadMetadata(sc, path)
(loadedClassName, formatVersion) match {
- case (classNameV1_0, "1.0") =>
- val model = SaveLoadV1_0.load(sc, path, rootId)
+ case (SaveLoadV1_0.thisClassName, SaveLoadV1_0.thisFormatVersion) =>
+ val model = SaveLoadV1_0.load(sc, path)
+ model
+ case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) =>
+ val model = SaveLoadV2_0.load(sc, path)
model
case _ => throw new Exception(
s"BisectingKMeansModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $formatVersion). Supported:\n" +
- s" ($classNameV1_0, 1.0)")
+ s" (${SaveLoadV1_0.thisClassName}, ${SaveLoadV1_0.thisClassName}\n" +
+ s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisClassName})")
}
}
@@ -136,8 +144,28 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
r.getDouble(4), r.getDouble(5), r.getSeq[Int](6))
}
+ private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
+ if (node.children.isEmpty) {
+ Array(node)
+ } else {
+ node.children.flatMap(getNodes) ++ Array(node)
+ }
+ }
+
+ private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
+ val root = nodes(rootId)
+ if (root.children.isEmpty) {
+ new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+ root.cost, root.height, new Array[ClusteringTreeNode](0))
+ } else {
+ val children = root.children.map(c => buildTree(c, nodes))
+ new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+ root.cost, root.height, children.toArray)
+ }
+ }
+
private[clustering] object SaveLoadV1_0 {
- private val thisFormatVersion = "1.0"
+ private[clustering] val thisFormatVersion = "1.0"
private[clustering]
val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
@@ -155,34 +183,55 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
spark.createDataFrame(data).write.parquet(Loader.dataPath(path))
}
- private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
- if (node.children.isEmpty) {
- Array(node)
- } else {
- node.children.flatMap(getNodes(_)) ++ Array(node)
- }
- }
-
- def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
+ def load(sc: SparkContext, path: String): BisectingKMeansModel = {
+ implicit val formats: DefaultFormats = DefaultFormats
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+ val rootId = (metadata \ "rootId").extract[Int]
val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val rows = spark.read.parquet(Loader.dataPath(path))
Loader.checkSchema[Data](rows.schema)
val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
val rootNode = buildTree(rootId, nodes)
- new BisectingKMeansModel(rootNode)
+ new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN)
}
+ }
+
+ private[clustering] object SaveLoadV2_0 {
+ private[clustering] val thisFormatVersion = "2.0"
- private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
- val root = nodes.get(rootId).get
- if (root.children.isEmpty) {
- new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
- root.cost, root.height, new Array[ClusteringTreeNode](0))
- } else {
- val children = root.children.map(c => buildTree(c, nodes))
- new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
- root.cost, root.height, children.toArray)
- }
+ private[clustering]
+ val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
+
+ def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
+ ~ ("rootId" -> model.root.index) ~ ("distanceMeasure" -> model.distanceMeasure)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val data = getNodes(model.root).map(node => Data(node.index, node.size,
+ node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height,
+ node.children.map(_.index)))
+ spark.createDataFrame(data).write.parquet(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): BisectingKMeansModel = {
+ implicit val formats: DefaultFormats = DefaultFormats
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+ val rootId = (metadata \ "rootId").extract[Int]
+ val distanceMeasure = (metadata \ "distanceMeasure").extract[String]
+ val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+ val rows = spark.read.parquet(Loader.dataPath(path))
+ Loader.checkSchema[Data](rows.schema)
+ val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
+ val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
+ val rootNode = buildTree(rootId, nodes)
+ new BisectingKMeansModel(rootNode, distanceMeasure)
}
}
}
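The 2.0 metadata now carries the distance measure, so a persisted model keeps it across a round trip. A sketch, assuming `model.save` routes through the new 2.0 writer and with an illustrative path:

    model.save(sc, "/tmp/bkm")
    val reloaded = BisectingKMeansModel.load(sc, "/tmp/bkm")
    assert(reloaded.distanceMeasure == model.distanceMeasure)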
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
new file mode 100644
index 0000000000000..683360efabc76
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
+import org.apache.spark.mllib.util.MLUtils
+
+private[spark] abstract class DistanceMeasure extends Serializable {
+
+ /**
+ * @return the index of the closest center to the given point, as well as the cost.
+ */
+ def findClosest(
+ centers: TraversableOnce[VectorWithNorm],
+ point: VectorWithNorm): (Int, Double) = {
+ var bestDistance = Double.PositiveInfinity
+ var bestIndex = 0
+ var i = 0
+ centers.foreach { center =>
+ val currentDistance = distance(center, point)
+ if (currentDistance < bestDistance) {
+ bestDistance = currentDistance
+ bestIndex = i
+ }
+ i += 1
+ }
+ (bestIndex, bestDistance)
+ }
+
+ /**
+ * @return the K-means cost of a given point against the given cluster centers.
+ */
+ def pointCost(
+ centers: TraversableOnce[VectorWithNorm],
+ point: VectorWithNorm): Double = {
+ findClosest(centers, point)._2
+ }
+
+ /**
+ * @return whether a center converged or not, given the epsilon parameter.
+ */
+ def isCenterConverged(
+ oldCenter: VectorWithNorm,
+ newCenter: VectorWithNorm,
+ epsilon: Double): Boolean = {
+ distance(oldCenter, newCenter) <= epsilon
+ }
+
+ /**
+ * @return the distance between two points.
+ */
+ def distance(
+ v1: VectorWithNorm,
+ v2: VectorWithNorm): Double
+
+ /**
+ * @return the total cost of the cluster from its aggregated properties
+ */
+ def clusterCost(
+ centroid: VectorWithNorm,
+ pointsSum: VectorWithNorm,
+ numberOfPoints: Long,
+ pointsSquaredNorm: Double): Double
+
+ /**
+ * Updates the value of `sum` adding the `point` vector.
+ * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+ * @param sum the `sum` for a cluster to be updated
+ */
+ def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+ axpy(1.0, point.vector, sum)
+ }
+
+ /**
+ * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+ *
+ * @param sum the `sum` for a cluster
+ * @param count the number of points in the cluster
+ * @return the centroid of the cluster
+ */
+ def centroid(sum: Vector, count: Long): VectorWithNorm = {
+ scal(1.0 / count, sum)
+ new VectorWithNorm(sum)
+ }
+
+ /**
+ * Returns two new centroids symmetric to the specified centroid, applying `noise` with
+ * the specified `level`.
+ *
+ * @param level the level of `noise` to apply to the given centroid.
+ * @param noise a noise vector
+ * @param centroid the parent centroid
+ * @return a left and right centroid symmetric to `centroid`
+ */
+ def symmetricCentroids(
+ level: Double,
+ noise: Vector,
+ centroid: Vector): (VectorWithNorm, VectorWithNorm) = {
+ val left = centroid.copy
+ axpy(-level, noise, left)
+ val right = centroid.copy
+ axpy(level, noise, right)
+ (new VectorWithNorm(left), new VectorWithNorm(right))
+ }
+
+ /**
+ * @return the cost of a point to be assigned to the cluster centroid
+ */
+ def cost(
+ point: VectorWithNorm,
+ centroid: VectorWithNorm): Double = distance(point, centroid)
+}
+
+@Since("2.4.0")
+object DistanceMeasure {
+
+ @Since("2.4.0")
+ val EUCLIDEAN = "euclidean"
+ @Since("2.4.0")
+ val COSINE = "cosine"
+
+ private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure =
+ distanceMeasure match {
+ case EUCLIDEAN => new EuclideanDistanceMeasure
+ case COSINE => new CosineDistanceMeasure
+ case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " +
+ s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.")
+ }
+
+ private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = {
+ distanceMeasure match {
+ case DistanceMeasure.EUCLIDEAN => true
+ case DistanceMeasure.COSINE => true
+ case _ => false
+ }
+ }
+}
+
+private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
+ /**
+ * @return the index of the closest center to the given point, as well as the squared distance.
+ */
+ override def findClosest(
+ centers: TraversableOnce[VectorWithNorm],
+ point: VectorWithNorm): (Int, Double) = {
+ var bestDistance = Double.PositiveInfinity
+ var bestIndex = 0
+ var i = 0
+ centers.foreach { center =>
+ // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
+ // distance computation.
+ var lowerBoundOfSqDist = center.norm - point.norm
+ lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
+ if (lowerBoundOfSqDist < bestDistance) {
+ val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
+ if (distance < bestDistance) {
+ bestDistance = distance
+ bestIndex = i
+ }
+ }
+ i += 1
+ }
+ (bestIndex, bestDistance)
+ }
+
+ /**
+ * @return whether a center converged or not, given the epsilon parameter.
+ */
+ override def isCenterConverged(
+ oldCenter: VectorWithNorm,
+ newCenter: VectorWithNorm,
+ epsilon: Double): Boolean = {
+ EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon
+ }
+
+ /**
+ * @param v1: first vector
+ * @param v2: second vector
+ * @return the Euclidean distance between the two input vectors
+ */
+ override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+ Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2))
+ }
+
+ /**
+ * @return the total cost of the cluster from its aggregated properties
+ */
+ override def clusterCost(
+ centroid: VectorWithNorm,
+ pointsSum: VectorWithNorm,
+ numberOfPoints: Long,
+ pointsSquaredNorm: Double): Double = {
+ math.max(pointsSquaredNorm - numberOfPoints * centroid.norm * centroid.norm, 0.0)
+ }
+
+ /**
+ * @return the cost of a point to be assigned to the cluster centroid
+ */
+ override def cost(
+ point: VectorWithNorm,
+ centroid: VectorWithNorm): Double = {
+ EuclideanDistanceMeasure.fastSquaredDistance(point, centroid)
+ }
+}
+
+
+private[spark] object EuclideanDistanceMeasure {
+ /**
+ * @return the squared Euclidean distance between two vectors computed by
+ * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]].
+ */
+ private[clustering] def fastSquaredDistance(
+ v1: VectorWithNorm,
+ v2: VectorWithNorm): Double = {
+ MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
+ }
+}
+
+private[spark] class CosineDistanceMeasure extends DistanceMeasure {
+ /**
+ * @param v1: first vector
+ * @param v2: second vector
+ * @return the cosine distance between the two input vectors
+ */
+ override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+ assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
+ 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
+ }
+
+ /**
+ * Updates the value of `sum` adding the `point` vector.
+ * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+ * @param sum the `sum` for a cluster to be updated
+ */
+ override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+ assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.")
+ axpy(1.0 / point.norm, point.vector, sum)
+ }
+
+ /**
+ * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+ *
+ * @param sum the `sum` for a cluster
+ * @param count the number of points in the cluster
+ * @return the centroid of the cluster
+ */
+ override def centroid(sum: Vector, count: Long): VectorWithNorm = {
+ scal(1.0 / count, sum)
+ val norm = Vectors.norm(sum, 2)
+ scal(1.0 / norm, sum)
+ new VectorWithNorm(sum, 1)
+ }
+
+ /**
+ * @return the total cost of the cluster from its aggregated properties
+ */
+ override def clusterCost(
+ centroid: VectorWithNorm,
+ pointsSum: VectorWithNorm,
+ numberOfPoints: Long,
+ pointsSquaredNorm: Double): Double = {
+ val costVector = pointsSum.vector.copy
+ math.max(numberOfPoints - dot(centroid.vector, costVector) / centroid.norm, 0.0)
+ }
+
+ /**
+ * Returns two new centroids symmetric to the specified centroid, applying `noise` with
+ * the specified `level`.
+ *
+ * @param level the level of `noise` to apply to the given centroid.
+ * @param noise a noise vector
+ * @param centroid the parent centroid
+ * @return a left and right centroid symmetric to `centroid`
+ */
+ override def symmetricCentroids(
+ level: Double,
+ noise: Vector,
+ centroid: Vector): (VectorWithNorm, VectorWithNorm) = {
+ val (left, right) = super.symmetricCentroids(level, noise, centroid)
+ val leftVector = left.vector
+ val rightVector = right.vector
+ scal(1.0 / left.norm, leftVector)
+ scal(1.0 / right.norm, rightVector)
+ (new VectorWithNorm(leftVector, 1.0), new VectorWithNorm(rightVector, 1.0))
+ }
+}
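One behavioral difference worth noting: CosineDistanceMeasure keeps cluster sums and centroids on the unit sphere, so a centroid's norm is always 1, whereas the Euclidean centroid is a plain mean. A small sketch of that contract, usable only from inside the spark package since both classes are private[spark]:

    import org.apache.spark.mllib.linalg.Vectors

    val cosine = new CosineDistanceMeasure
    val c = cosine.centroid(Vectors.dense(3.0, 4.0), 1L)
    // c.vector == [0.6, 0.8], c.norm == 1.0

    val euclidean = new EuclideanDistanceMeasure
    val m = euclidean.centroid(Vectors.dense(3.0, 4.0), 1L)
    // m.vector == [3.0, 4.0], m.norm == 5.0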
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 607145cb59fba..b5b1be3490497 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -25,8 +25,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.clustering.{KMeans => NewKMeans}
import org.apache.spark.ml.util.Instrumentation
import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
-import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.linalg.BLAS.axpy
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -204,7 +203,7 @@ class KMeans private (
*/
@Since("2.4.0")
def setDistanceMeasure(distanceMeasure: String): this.type = {
- KMeans.validateDistanceMeasure(distanceMeasure)
+ DistanceMeasure.validateDistanceMeasure(distanceMeasure)
this.distanceMeasure = distanceMeasure
this
}
@@ -310,8 +309,7 @@ class KMeans private (
points.foreach { point =>
val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
costAccum.add(cost)
- val sum = sums(bestCenter)
- axpy(1.0, point.vector, sum)
+ distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
counts(bestCenter) += 1
}
@@ -319,10 +317,9 @@ class KMeans private (
}.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
axpy(1.0, sum2, sum1)
(sum1, count1 + count2)
- }.mapValues { case (sum, count) =>
- scal(1.0 / count, sum)
- new VectorWithNorm(sum)
- }.collectAsMap()
+ }.collectAsMap().mapValues { case (sum, count) =>
+ distanceMeasureInstance.centroid(sum, count)
+ }
bcCenters.destroy(blocking = false)
@@ -584,14 +581,6 @@ object KMeans {
case _ => false
}
}
-
- private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = {
- distanceMeasure match {
- case DistanceMeasure.EUCLIDEAN => true
- case DistanceMeasure.COSINE => true
- case _ => false
- }
- }
}
/**
@@ -607,142 +596,3 @@ private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double)
/** Converts the vector to a dense vector. */
def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
}
-
-
-private[spark] abstract class DistanceMeasure extends Serializable {
-
- /**
- * @return the index of the closest center to the given point, as well as the cost.
- */
- def findClosest(
- centers: TraversableOnce[VectorWithNorm],
- point: VectorWithNorm): (Int, Double) = {
- var bestDistance = Double.PositiveInfinity
- var bestIndex = 0
- var i = 0
- centers.foreach { center =>
- val currentDistance = distance(center, point)
- if (currentDistance < bestDistance) {
- bestDistance = currentDistance
- bestIndex = i
- }
- i += 1
- }
- (bestIndex, bestDistance)
- }
-
- /**
- * @return the K-means cost of a given point against the given cluster centers.
- */
- def pointCost(
- centers: TraversableOnce[VectorWithNorm],
- point: VectorWithNorm): Double = {
- findClosest(centers, point)._2
- }
-
- /**
- * @return whether a center converged or not, given the epsilon parameter.
- */
- def isCenterConverged(
- oldCenter: VectorWithNorm,
- newCenter: VectorWithNorm,
- epsilon: Double): Boolean = {
- distance(oldCenter, newCenter) <= epsilon
- }
-
- /**
- * @return the cosine distance between two points.
- */
- def distance(
- v1: VectorWithNorm,
- v2: VectorWithNorm): Double
-
-}
-
-@Since("2.4.0")
-object DistanceMeasure {
-
- @Since("2.4.0")
- val EUCLIDEAN = "euclidean"
- @Since("2.4.0")
- val COSINE = "cosine"
-
- private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure =
- distanceMeasure match {
- case EUCLIDEAN => new EuclideanDistanceMeasure
- case COSINE => new CosineDistanceMeasure
- case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " +
- s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.")
- }
-}
-
-private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
- /**
- * @return the index of the closest center to the given point, as well as the squared distance.
- */
- override def findClosest(
- centers: TraversableOnce[VectorWithNorm],
- point: VectorWithNorm): (Int, Double) = {
- var bestDistance = Double.PositiveInfinity
- var bestIndex = 0
- var i = 0
- centers.foreach { center =>
- // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
- // distance computation.
- var lowerBoundOfSqDist = center.norm - point.norm
- lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
- if (lowerBoundOfSqDist < bestDistance) {
- val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
- if (distance < bestDistance) {
- bestDistance = distance
- bestIndex = i
- }
- }
- i += 1
- }
- (bestIndex, bestDistance)
- }
-
- /**
- * @return whether a center converged or not, given the epsilon parameter.
- */
- override def isCenterConverged(
- oldCenter: VectorWithNorm,
- newCenter: VectorWithNorm,
- epsilon: Double): Boolean = {
- EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon
- }
-
- /**
- * @param v1: first vector
- * @param v2: second vector
- * @return the Euclidean distance between the two input vectors
- */
- override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
- Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2))
- }
-}
-
-
-private[spark] object EuclideanDistanceMeasure {
- /**
- * @return the squared Euclidean distance between two vectors computed by
- * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]].
- */
- private[clustering] def fastSquaredDistance(
- v1: VectorWithNorm,
- v2: VectorWithNorm): Double = {
- MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
- }
-}
-
-private[spark] class CosineDistanceMeasure extends DistanceMeasure {
- /**
- * @param v1: first vector
- * @param v2: second vector
- * @return the cosine distance between the two input vectors
- */
- override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
- 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index b8a6e94248421..f915062d77389 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.util.BoundedPriorityQueue
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
/**
* Latent Dirichlet Allocation (LDA) model.
@@ -194,6 +194,8 @@ class LocalLDAModel private[spark] (
override protected[spark] val gammaShape: Double = 100)
extends LDAModel with Serializable {
+ private var seed: Long = Utils.random.nextLong()
+
@Since("1.3.0")
override def k: Int = topics.numCols
@@ -216,6 +218,21 @@ class LocalLDAModel private[spark] (
override protected def formatVersion = "1.0"
+ /**
+ * Random seed used in the variational inference of topic distributions.
+ */
+ @Since("2.4.0")
+ def getSeed: Long = seed
+
+ /**
+ * Set the random seed used in the variational inference of topic distributions.
+ */
+ @Since("2.4.0")
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
@Since("1.5.0")
override def save(sc: SparkContext, path: String): Unit = {
LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
@@ -298,6 +315,7 @@ class LocalLDAModel private[spark] (
// by topic (columns of lambda)
val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta)
+ val gammaSeed = this.seed
// Sum bound components for each document:
// component for prob(tokens) + component for prob(document-topic distribution)
@@ -306,7 +324,7 @@ class LocalLDAModel private[spark] (
val localElogbeta = ElogbetaBc.value
var docBound = 0.0D
val (gammad: BDV[Double], _, _) = OnlineLDAOptimizer.variationalTopicInference(
- termCounts, exp(localElogbeta), brzAlpha, gammaShape, k)
+ termCounts, exp(localElogbeta), brzAlpha, gammaShape, k, gammaSeed + id)
val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
// E[log p(doc | theta, beta)]
@@ -352,6 +370,7 @@ class LocalLDAModel private[spark] (
val docConcentrationBrz = this.docConcentration.asBreeze
val gammaShape = this.gammaShape
val k = this.k
+ val gammaSeed = this.seed
documents.map { case (id: Long, termCounts: Vector) =>
if (termCounts.numNonzeros == 0) {
@@ -362,7 +381,8 @@ class LocalLDAModel private[spark] (
expElogbetaBc.value,
docConcentrationBrz,
gammaShape,
- k)
+ k,
+ gammaSeed + id)
(id, Vectors.dense(normalize(gamma, 1.0).toArray))
}
}
@@ -376,6 +396,7 @@ class LocalLDAModel private[spark] (
val docConcentrationBrz = this.docConcentration.asBreeze
val gammaShape = this.gammaShape
val k = this.k
+ val gammaSeed = this.seed
(termCounts: Vector) =>
if (termCounts.numNonzeros == 0) {
@@ -386,7 +407,8 @@ class LocalLDAModel private[spark] (
expElogbeta,
docConcentrationBrz,
gammaShape,
- k)
+ k,
+ gammaSeed)
Vectors.dense(normalize(gamma, 1.0).toArray)
}
}
@@ -403,6 +425,7 @@ class LocalLDAModel private[spark] (
*/
@Since("2.0.0")
def topicDistribution(document: Vector): Vector = {
+ val gammaSeed = this.seed
val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t)
if (document.numNonzeros == 0) {
Vectors.zeros(this.k)
@@ -412,7 +435,8 @@ class LocalLDAModel private[spark] (
expElogbeta,
this.docConcentration.asBreeze,
gammaShape,
- this.k)
+ this.k,
+ gammaSeed)
Vectors.dense(normalize(gamma, 1.0).toArray)
}
}
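Because topicDistribution now seeds its Gamma initialization from the model's seed, predictions become reproducible once the seed is pinned. A sketch, assuming an existing LocalLDAModel `ldaModel` and a term-count vector `doc`:

    ldaModel.setSeed(42L)
    val dist1 = ldaModel.topicDistribution(doc)
    val dist2 = ldaModel.topicDistribution(doc)
    // dist1 == dist2: variational inference draws from a Gamma seeded with getSeed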
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 693a2a31f026b..f8e5f3ed76457 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -464,6 +465,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
val alpha = this.alpha.asBreeze
val gammaShape = this.gammaShape
val optimizeDocConcentration = this.optimizeDocConcentration
+ val seed = randomGenerator.nextLong()
// If and only if optimizeDocConcentration is set true,
// we calculate logphat in the same pass as other statistics.
// No calculation of loghat happens otherwise.
@@ -473,20 +475,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
None
}
- val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs =>
- val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
-
- val stat = BDM.zeros[Double](k, vocabSize)
- val logphatPartOption = logphatPartOptionBase()
- var nonEmptyDocCount: Long = 0L
- nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
- nonEmptyDocCount += 1
- val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
- termCounts, expElogbetaBc.value, alpha, gammaShape, k)
- stat(::, ids) := stat(::, ids) + sstats
- logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
- }
- Iterator((stat, logphatPartOption, nonEmptyDocCount))
+ val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitionsWithIndex {
+ (index, docs) =>
+ val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
+
+ val stat = BDM.zeros[Double](k, vocabSize)
+ val logphatPartOption = logphatPartOptionBase()
+ var nonEmptyDocCount: Long = 0L
+ nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
+ nonEmptyDocCount += 1
+ val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
+ termCounts, expElogbetaBc.value, alpha, gammaShape, k, seed + index)
+ stat(::, ids) := stat(::, ids) + sstats
+ logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
+ }
+ Iterator((stat, logphatPartOption, nonEmptyDocCount))
}
val elementWiseSum = (
@@ -578,7 +581,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
}
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
- new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape)
+ new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta)
+ .setSeed(randomGenerator.nextLong())
}
}
@@ -605,18 +609,20 @@ private[clustering] object OnlineLDAOptimizer {
expElogbeta: BDM[Double],
alpha: breeze.linalg.Vector[Double],
gammaShape: Double,
- k: Int): (BDV[Double], BDM[Double], List[Int]) = {
+ k: Int,
+ seed: Long): (BDV[Double], BDM[Double], List[Int]) = {
val (ids: List[Int], cts: Array[Double]) = termCounts match {
case v: DenseVector => ((0 until v.size).toList, v.values)
case v: SparseVector => (v.indices.toList, v.values)
}
// Initialize the variational distribution q(theta|gamma) for the mini-batch
+ val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed))
val gammad: BDV[Double] =
- new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
+ new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k) // K
val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K
val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K
- val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids
+ val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids
var meanGammaChange = 1D
val ctsVector = new BDV[Double](cts) // ids
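The optimizer now derives one deterministic seed per partition (`seed + index`), so mini-batch inference is reproducible without every partition drawing the same random numbers. The same pattern in isolation, with illustrative names and an assumed `data: RDD[Double]`:

    val baseSeed = 12345L
    val perturbed = data.mapPartitionsWithIndex { (index, iter) =>
      // Distinct per partition, but stable across re-executions of the same stage.
      val rng = new scala.util.Random(baseSeed + index)
      iter.map(x => x + rng.nextGaussian())
    }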
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 3ca75e8cdb97a..7a5e520d5818e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -43,7 +43,7 @@ import org.apache.spark.util.random.XORShiftRandom
* $$
* \begin{align}
* c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\
- * n_t+t &= n_t * a + m_t
+ * n_t+1 &= n_t * a + m_t
* \end{align}
* $$
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
index 9abdd44a635d1..7b73b286fb91c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -135,7 +135,7 @@ object HashingTF {
private[HashingTF] val Murmur3: String = "murmur3"
- private val seed = 42
+ private[spark] val seed = 42
/**
* Calculate a hash code value for the term object using the native Scala implementation.
@@ -160,7 +160,7 @@ object HashingTF {
case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
case s: String =>
val utf8 = UTF8String.fromString(s)
- hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
+ hashUnsafeBytesBlock(utf8.getMemoryBlock(), seed)
case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " +
s"support type ${term.getClass.getCanonicalName} of input data.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index f6b1143272d16..4f2b7e6f0764e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -162,7 +162,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
*
*/
@Since("1.3.0")
-class FPGrowth private (
+class FPGrowth private[spark] (
private var minSupport: Double,
private var numPartitions: Int) extends Logging with Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 3f8d65a378e2c..7aed2f3bd8a61 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -49,8 +49,7 @@ import org.apache.spark.storage.StorageLevel
*
* @param minSupport the minimal support level of the sequential pattern, any pattern that appears
* more than (minSupport * size-of-the-dataset) times will be output
- * @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears
- * less than maxPatternLength will be output
+ * @param maxPatternLength the maximal length of the sequential pattern
* @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal
* storage format) allowed in a projected database before local
* processing. If a projected database exceeds this size, another
diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java
index 43779878890db..35a250955b282 100644
--- a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java
+++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java
@@ -42,7 +42,12 @@ public void setUp() throws IOException {
@After
public void tearDown() {
- spark.stop();
- spark = null;
+ try {
+ spark.stop();
+ spark = null;
+ } finally {
+ SparkSession.clearDefaultSession();
+ SparkSession.clearActiveSession();
+ }
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java
new file mode 100644
index 0000000000000..830f668fe07b8
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.commons.math3.distribution.NormalDistribution;
+import org.apache.spark.sql.Encoders;
+import org.junit.Test;
+
+import org.apache.spark.SharedSparkSession;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+
+
+public class JavaKolmogorovSmirnovTestSuite extends SharedSparkSession {
+
+ private transient Dataset<Row> dataset;
+
+ @Override
+ public void setUp() throws IOException {
+ super.setUp();
+ List<Double> points = Arrays.asList(0.1, 1.1, 10.1, -1.1);
+
+ dataset = spark.createDataset(points, Encoders.DOUBLE()).toDF("sample");
+ }
+
+ @Test
+ public void testKSTestCDF() {
+ // Create theoretical distributions
+ NormalDistribution stdNormalDist = new NormalDistribution(0, 1);
+
+ // set seeds
+ Long seed = 10L;
+ stdNormalDist.reseedRandomGenerator(seed);
+ Function<Double, Double> stdNormalCDF = (x) -> stdNormalDist.cumulativeProbability(x);
+
+ double pThreshold = 0.05;
+
+ // Comparing a standard normal sample to a standard normal distribution
+ Row results = KolmogorovSmirnovTest
+ .test(dataset, "sample", stdNormalCDF).head();
+ double pValue1 = results.getDouble(0);
+ // Cannot reject null hypothesis
+ assert(pValue1 > pThreshold);
+ }
+
+ @Test
+ public void testKSTestNamedDistribution() {
+ double pThreshold = 0.05;
+
+ // Comparing a standard normal sample to a standard normal distribution
+ Row results = KolmogorovSmirnovTest
+ .test(dataset, "sample", "norm", 0.0, 1.0).head();
+ double pValue1 = results.getDouble(0);
+ // Cannot reject null hypothesis
+ assert(pValue1 > pThreshold);
+ }
+}
diff --git a/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
new file mode 100644
index 0000000000000..100ef2545418f
--- /dev/null
+++ b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
@@ -0,0 +1,3 @@
+org.apache.spark.ml.util.DuplicateLinearRegressionWriter1
+org.apache.spark.ml.util.DuplicateLinearRegressionWriter2
+org.apache.spark.ml.util.FakeLinearRegressionWriterWithName
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 98c879ece62d6..d3dbb4e754d3d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,17 +21,16 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
+import org.apache.spark.ml.tree.ClassificationLeafNode
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+ DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
-class DecisionTreeClassifierSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
import DecisionTreeClassifierSuite.compareAPIs
import testImplicits._
@@ -62,7 +61,8 @@ class DecisionTreeClassifierSuite
test("params") {
ParamsSuite.checkParams(new DecisionTreeClassifier)
- val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)
+ val model = new DecisionTreeClassificationModel("dtc",
+ new ClassificationLeafNode(0.0, 0.0, null), 1, 2)
ParamsSuite.checkParams(model)
}
@@ -251,20 +251,33 @@ class DecisionTreeClassifierSuite
MLTestingUtils.checkCopyAndUids(dt, newTree)
- val predictions = newTree.transform(newData)
- .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
- .collect()
-
- predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
- assert(pred === rawPred.argmax,
- s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
- val sum = rawPred.toArray.sum
- assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
- "probability prediction mismatch")
+ testTransformer[(Vector, Double)](newData, newTree,
+ "prediction", "rawPrediction", "probability") {
+ case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+ assert(pred === rawPred.argmax,
+ s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+ val sum = rawPred.toArray.sum
+ assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+ "probability prediction mismatch")
}
ProbabilisticClassifierSuite.testPredictMethods[
- Vector, DecisionTreeClassificationModel](newTree, newData)
+ Vector, DecisionTreeClassificationModel](this, newTree, newData)
+ }
+
+ test("prediction on single instance") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3)
+ val numClasses = 3
+
+ val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val newTree = dt.fit(newData)
+
+ testPredictionModelSinglePrediction(newTree, newData)
}
test("training with 1-category categorical feature") {
@@ -280,44 +293,6 @@ class DecisionTreeClassifierSuite
dt.fit(df)
}
- test("Use soft prediction for binary classification with ordered categorical features") {
- // The following dataset is set up such that the best split is {1} vs. {0, 2}.
- // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
- val arr = Array(
- LabeledPoint(0.0, Vectors.dense(0.0)),
- LabeledPoint(0.0, Vectors.dense(0.0)),
- LabeledPoint(0.0, Vectors.dense(0.0)),
- LabeledPoint(1.0, Vectors.dense(0.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(1.0)),
- LabeledPoint(0.0, Vectors.dense(2.0)),
- LabeledPoint(0.0, Vectors.dense(2.0)),
- LabeledPoint(0.0, Vectors.dense(2.0)),
- LabeledPoint(1.0, Vectors.dense(2.0)))
- val data = sc.parallelize(arr)
- val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
-
- // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
- val dt = new DecisionTreeClassifier()
- .setImpurity("gini")
- .setMaxDepth(1)
- .setMaxBins(3)
- val model = dt.fit(df)
- model.rootNode match {
- case n: InternalNode =>
- n.split match {
- case s: CategoricalSplit =>
- assert(s.leftCategories === Array(1.0))
- case other =>
- fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.")
- }
- case other =>
- fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.")
- }
- }
-
test("Feature importance with toy data") {
val dt = new DecisionTreeClassifier()
.setImpurity("gini")
@@ -401,6 +376,32 @@ class DecisionTreeClassifierSuite
testDefaultReadWrite(model)
}
+
+ test("label/impurity stats") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+ val rdd = sc.parallelize(arr)
+ val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2)
+ val dt1 = new DecisionTreeClassifier()
+ .setImpurity("entropy")
+ .setMaxDepth(2)
+ .setMinInstancesPerNode(2)
+ val model1 = dt1.fit(df)
+
+ val rootNode1 = model1.rootNode
+ assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0))
+
+ val dt2 = new DecisionTreeClassifier()
+ .setImpurity("gini")
+ .setMaxDepth(2)
+ .setMinInstancesPerNode(2)
+ val model2 = dt2.fit(df)
+
+ val rootNode2 = model2.rootNode
+ assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0))
+ }
}
private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 978f89c459f0a..e6d2a8e2b900e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -24,24 +24,23 @@ import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.LeafNode
-import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.tree.RegressionLeafNode
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.loss.LogLoss
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.lit
import org.apache.spark.util.Utils
/**
* Test suite for [[GBTClassifier]].
*/
-class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
- with DefaultReadWriteTest {
+class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
import GBTClassifierSuite.compareAPIs
@@ -71,7 +70,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("params") {
ParamsSuite.checkParams(new GBTClassifier)
val model = new GBTClassificationModel("gbtc",
- Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
+ Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)),
Array(1.0), 1, 2)
ParamsSuite.checkParams(model)
}
@@ -126,14 +125,15 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
// should predict all zeros
binaryModel.setThresholds(Array(0.0, 1.0))
- val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect()
- assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0))
+ testTransformer[(Double, Vector)](df, binaryModel, "prediction") {
+ case Row(prediction: Double) => assert(prediction === 0.0)
+ }
// should predict all ones
binaryModel.setThresholds(Array(1.0, 0.0))
- val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect()
- assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
-
+ testTransformer[(Double, Vector)](df, binaryModel, "prediction") {
+ case Row(prediction: Double) => assert(prediction === 1.0)
+ }
val gbtBase = new GBTClassifier
val model = gbtBase.fit(df)
@@ -141,15 +141,18 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
// constant threshold scaling is the same as no thresholds
binaryModel.setThresholds(Array(1.0, 1.0))
- val scaledPredictions = binaryModel.transform(df).select("prediction").collect()
- assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
- scaled.getDouble(0) === base.getDouble(0)
- })
+ testTransformerByGlobalCheckFunc[(Double, Vector)](df, binaryModel, "prediction") {
+ scaledPredictions: Seq[Row] =>
+ assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
+ scaled.getDouble(0) === base.getDouble(0)
+ })
+ }
// force it to use the predict method
model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1))
- val predictionsWithPredict = model.transform(df).select("prediction").collect()
- assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0))
+ testTransformer[(Double, Vector)](df, model, "prediction") {
+ case Row(prediction: Double) => assert(prediction === 0.0)
+ }
}
test("GBTClassifier: Predictor, Classifier methods") {
@@ -169,61 +172,39 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
val blas = BLAS.getInstance()
val validationDataset = validationData.toDF(labelCol, featuresCol)
- val results = gbtModel.transform(validationDataset)
- // check that raw prediction is tree predictions dot tree weights
- results.select(rawPredictionCol, featuresCol).collect().foreach {
- case Row(raw: Vector, features: Vector) =>
+ testTransformer[(Double, Vector)](validationDataset, gbtModel,
+ "rawPrediction", "features", "probability", "prediction") {
+ case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) =>
assert(raw.size === 2)
+ // check that raw prediction is tree predictions dot tree weights
val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
- }
- // Compare rawPrediction with probability
- results.select(rawPredictionCol, probabilityCol).collect().foreach {
- case Row(raw: Vector, prob: Vector) =>
- assert(raw.size === 2)
+ // Compare rawPrediction with probability
assert(prob.size === 2)
// Note: we should check other loss types for classification if they are added
val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value))
assert(prob(0) ~== predFromRaw(0) relTol eps)
assert(prob(1) ~== predFromRaw(1) relTol eps)
assert(prob(0) + prob(1) ~== 1.0 absTol absEps)
- }
- // Compare prediction with probability
- results.select(predictionCol, probabilityCol).collect().foreach {
- case Row(pred: Double, prob: Vector) =>
+ // Compare prediction with probability
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
assert(pred == predFromProb)
}
- // force it to use raw2prediction
- gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("")
- val resultsUsingRaw2Predict =
- gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
- resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
- case (pred1, pred2) => assert(pred1 === pred2)
- }
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, GBTClassificationModel](this, gbtModel, validationDataset)
+ }
- // force it to use probability2prediction
- gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol)
- val resultsUsingProb2Predict =
- gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
- resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
- case (pred1, pred2) => assert(pred1 === pred2)
- }
+ test("prediction on single instance") {
- // force it to use predict
- gbtModel.setRawPredictionCol("").setProbabilityCol("")
- val resultsUsingPredict =
- gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
- resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
- case (pred1, pred2) => assert(pred1 === pred2)
- }
+ val gbt = new GBTClassifier().setSeed(123)
+ val trainingDataset = trainData.toDF("label", "features")
+ val gbtModel = gbt.fit(trainingDataset)
- ProbabilisticClassifierSuite.testPredictMethods[
- Vector, GBTClassificationModel](gbtModel, validationDataset)
+ testPredictionModelSinglePrediction(gbtModel, trainingDataset)
}
test("GBT parameter stepSize should be in interval (0, 1]") {
@@ -385,6 +366,78 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
assert(mostImportantFeature !== mostIF)
}
+ test("model evaluateEachIteration") {
+ val gbt = new GBTClassifier()
+ .setSeed(1L)
+ .setMaxDepth(2)
+ .setMaxIter(3)
+ .setLossType("logistic")
+ val model3 = gbt.fit(trainData.toDF)
+ val model1 = new GBTClassificationModel("gbt-cls-model-test1",
+ model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures, model3.numClasses)
+ val model2 = new GBTClassificationModel("gbt-cls-model-test2",
+ model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses)
+
+ val evalArr = model3.evaluateEachIteration(validationData.toDF)
+ val remappedValidationData = validationData.map(
+ x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData,
+ model1.trees, model1.treeWeights, model1.getOldLossType)
+ val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData,
+ model2.trees, model2.treeWeights, model2.getOldLossType)
+ val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData,
+ model3.trees, model3.treeWeights, model3.getOldLossType)
+
+ assert(evalArr(0) ~== lossErr1 relTol 1E-3)
+ assert(evalArr(1) ~== lossErr2 relTol 1E-3)
+ assert(evalArr(2) ~== lossErr3 relTol 1E-3)
+ }
+
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ val validationIndicatorCol = "validationIndicator"
+ val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false))
+ val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true))
+
+ val numIter = 20
+ for (lossType <- GBTClassifier.supportedLossTypes) {
+ val gbt = new GBTClassifier()
+ .setSeed(123)
+ .setMaxDepth(2)
+ .setLossType(lossType)
+ .setMaxIter(numIter)
+ val modelWithoutValidation = gbt.fit(trainDF)
+
+ gbt.setValidationIndicatorCol(validationIndicatorCol)
+ val modelWithValidation = gbt.fit(trainDF.union(validationDF))
+
+ assert(modelWithoutValidation.numTrees === numIter)
+ // early stop
+ assert(modelWithValidation.numTrees < numIter)
+
+ val (errorWithoutValidation, errorWithValidation) = {
+ val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+ (GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees,
+ modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType),
+ GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees,
+ modelWithValidation.treeWeights, modelWithValidation.getOldLossType))
+ }
+ assert(errorWithValidation < errorWithoutValidation)
+
+ val evaluationArray = GradientBoostedTrees
+ .evaluateEachIteration(validationData, modelWithoutValidation.trees,
+ modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
+ OldAlgo.Classification)
+ assert(evaluationArray.length === numIter)
+ assert(evaluationArray(modelWithValidation.numTrees) >
+ evaluationArray(modelWithValidation.numTrees - 1))
+ var i = 1
+ while (i < modelWithValidation.numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index 41a5d22dd6283..c05c896df5cb1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -21,20 +21,18 @@ import scala.util.Random
import breeze.linalg.{DenseVector => BDV}
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.LinearSVCSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.optim.aggregator.HingeAggregator
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.udf
-class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -141,10 +139,11 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
threshold: Double,
expected: Set[(Int, Double)]): Unit = {
model.setThreshold(threshold)
- val results = model.transform(df).select("id", "prediction").collect()
- .map(r => (r.getInt(0), r.getDouble(1)))
- .toSet
- assert(results === expected, s"Failed for threshold = $threshold")
+ testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") {
+ rows: Seq[Row] =>
+ val results = rows.map(r => (r.getInt(0), r.getDouble(1))).toSet
+ assert(results === expected, s"Failed for threshold = $threshold")
+ }
}
def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = {
@@ -202,6 +201,12 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
dataset.as[LabeledPoint], estimator, modelEquals, 42L)
}
+ test("prediction on single instance") {
+ val trainer = new LinearSVC()
+ val model = trainer.fit(smallBinaryDataset)
+ testPredictionModelSinglePrediction(model, smallBinaryDataset)
+ }
+
test("linearSVC comparison with R e1071 and scikit-learn") {
val trainer1 = new LinearSVC()
.setRegParam(0.00002) // set regParam = 2.0 / datasize / c
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a5f81a38face9..36b7e51f93d01 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -22,22 +22,20 @@ import scala.language.existentials
import scala.util.Random
import scala.util.control.Breaks._
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors}
import org.apache.spark.ml.optim.aggregator.LogisticAggregator
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit, rand}
import org.apache.spark.sql.types.LongType
-class LogisticRegressionSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -332,15 +330,14 @@ class LogisticRegressionSuite
val binaryModel = blr.fit(smallBinaryDataset)
binaryModel.setThreshold(1.0)
- val binaryZeroPredictions =
- binaryModel.transform(smallBinaryDataset).select("prediction").collect()
- assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0))
+ testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") {
+ row => assert(row.getDouble(0) === 0.0)
+ }
binaryModel.setThreshold(0.0)
- val binaryOnePredictions =
- binaryModel.transform(smallBinaryDataset).select("prediction").collect()
- assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
-
+ testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") {
+ row => assert(row.getDouble(0) === 1.0)
+ }
val mlr = new LogisticRegression().setFamily("multinomial")
val model = mlr.fit(smallMultinomialDataset)
@@ -348,31 +345,36 @@ class LogisticRegressionSuite
// should predict all zeros
model.setThresholds(Array(1, 1000, 1000))
- val zeroPredictions = model.transform(smallMultinomialDataset).select("prediction").collect()
- assert(zeroPredictions.forall(_.getDouble(0) === 0.0))
+ testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") {
+ row => assert(row.getDouble(0) === 0.0)
+ }
// should predict all ones
model.setThresholds(Array(1000, 1, 1000))
- val onePredictions = model.transform(smallMultinomialDataset).select("prediction").collect()
- assert(onePredictions.forall(_.getDouble(0) === 1.0))
+ testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") {
+ row => assert(row.getDouble(0) === 1.0)
+ }
// should predict all twos
model.setThresholds(Array(1000, 1000, 1))
- val twoPredictions = model.transform(smallMultinomialDataset).select("prediction").collect()
- assert(twoPredictions.forall(_.getDouble(0) === 2.0))
+ testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") {
+ row => assert(row.getDouble(0) === 2.0)
+ }
// constant threshold scaling is the same as no thresholds
model.setThresholds(Array(1000, 1000, 1000))
- val scaledPredictions = model.transform(smallMultinomialDataset).select("prediction").collect()
- assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
- scaled.getDouble(0) === base.getDouble(0)
- })
+ testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model,
+ "prediction") { scaledPredictions: Seq[Row] =>
+ assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
+ scaled.getDouble(0) === base.getDouble(0)
+ })
+ }
// force it to use the predict method
model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1, 1))
- val predictionsWithPredict =
- model.transform(smallMultinomialDataset).select("prediction").collect()
- assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0))
+ testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") {
+ row => assert(row.getDouble(0) === 0.0)
+ }
}
test("logistic regression doesn't fit intercept when fitIntercept is off") {
@@ -403,21 +405,19 @@ class LogisticRegressionSuite
// Modify model params, and check that the params worked.
model.setThreshold(1.0)
- val predAllZero = model.transform(smallBinaryDataset)
- .select("prediction", "myProbability")
- .collect()
- .map { case Row(pred: Double, prob: Vector) => pred }
- assert(predAllZero.forall(_ === 0),
- s"With threshold=1.0, expected predictions to be all 0, but only" +
- s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.")
+ testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(),
+ model, "prediction", "myProbability") { rows =>
+ val predAllZero = rows.map(_.getDouble(0))
+ assert(predAllZero.forall(_ === 0),
+ s"With threshold=1.0, expected predictions to be all 0, but only" +
+ s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.")
+ }
// Call transform with params, and check that the params worked.
- val predNotAllZero =
- model.transform(smallBinaryDataset, model.threshold -> 0.0,
- model.probabilityCol -> "myProb")
- .select("prediction", "myProb")
- .collect()
- .map { case Row(pred: Double, prob: Vector) => pred }
- assert(predNotAllZero.exists(_ !== 0.0))
+ testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(),
+ model.copy(ParamMap(model.threshold -> 0.0,
+ model.probabilityCol -> "myProb")), "prediction", "myProb") {
+ rows => assert(rows.map(_.getDouble(0)).exists(_ !== 0.0))
+ }
// Call fit() with new params, and check as many params as we can.
lr.setThresholds(Array(0.6, 0.4))
@@ -441,10 +441,10 @@ class LogisticRegressionSuite
val numFeatures = smallMultinomialDataset.select("features").first().getAs[Vector](0).size
assert(model.numFeatures === numFeatures)
- val results = model.transform(smallMultinomialDataset)
- // check that raw prediction is coefficients dot features + intercept
- results.select("rawPrediction", "features").collect().foreach {
- case Row(raw: Vector, features: Vector) =>
+ testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(),
+ model, "rawPrediction", "features", "probability") {
+ case Row(raw: Vector, features: Vector, prob: Vector) =>
+ // check that raw prediction is coefficients dot features + intercept
assert(raw.size === 3)
val margins = Array.tabulate(3) { k =>
var margin = 0.0
@@ -455,12 +455,7 @@ class LogisticRegressionSuite
margin
}
assert(raw ~== Vectors.dense(margins) relTol eps)
- }
-
- // Compare rawPrediction with probability
- results.select("rawPrediction", "probability").collect().foreach {
- case Row(raw: Vector, prob: Vector) =>
- assert(raw.size === 3)
+ // Compare rawPrediction with probability
assert(prob.size === 3)
val max = raw.toArray.max
val subtract = if (max > 0) max else 0.0
@@ -472,39 +467,8 @@ class LogisticRegressionSuite
assert(prob(2) ~== 1.0 - probFromRaw1 - probFromRaw0 relTol eps)
}
- // Compare prediction with probability
- results.select("prediction", "probability").collect().foreach {
- case Row(pred: Double, prob: Vector) =>
- val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
- assert(pred == predFromProb)
- }
-
- // force it to use raw2prediction
- model.setRawPredictionCol("rawPrediction").setProbabilityCol("")
- val resultsUsingRaw2Predict =
- model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
- resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
- case (pred1, pred2) => assert(pred1 === pred2)
- }
-
- // force it to use probability2prediction
- model.setRawPredictionCol("").setProbabilityCol("probability")
- val resultsUsingProb2Predict =
- model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
- resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
- case (pred1, pred2) => assert(pred1 === pred2)
- }
-
- // force it to use predict
- model.setRawPredictionCol("").setProbabilityCol("")
- val resultsUsingPredict =
- model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
- resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
- case (pred1, pred2) => assert(pred1 === pred2)
- }
-
ProbabilisticClassifierSuite.testPredictMethods[
- Vector, LogisticRegressionModel](model, smallMultinomialDataset)
+ Vector, LogisticRegressionModel](this, model, smallMultinomialDataset)
}
test("binary logistic regression: Predictor, Classifier methods") {
@@ -517,51 +481,31 @@ class LogisticRegressionSuite
val numFeatures = smallBinaryDataset.select("features").first().getAs[Vector](0).size
assert(model.numFeatures === numFeatures)
- val results = model.transform(smallBinaryDataset)
-
- // Compare rawPrediction with probability
- results.select("rawPrediction", "probability").collect().foreach {
- case Row(raw: Vector, prob: Vector) =>
+ testTransformer[(Double, Vector)](smallBinaryDataset.toDF(),
+ model, "rawPrediction", "probability", "prediction") {
+ case Row(raw: Vector, prob: Vector, pred: Double) =>
+ // Compare rawPrediction with probability
assert(raw.size === 2)
assert(prob.size === 2)
val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1)))
assert(prob(1) ~== probFromRaw1 relTol eps)
assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps)
- }
-
- // Compare prediction with probability
- results.select("prediction", "probability").collect().foreach {
- case Row(pred: Double, prob: Vector) =>
+ // Compare prediction with probability
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
assert(pred == predFromProb)
}
- // force it to use raw2prediction
- model.setRawPredictionCol("rawPrediction").setProbabilityCol("")
- val resultsUsingRaw2Predict =
- model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
- resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
- case (pred1, pred2) => assert(pred1 === pred2)
- }
-
- // force it to use probability2prediction
- model.setRawPredictionCol("").setProbabilityCol("probability")
- val resultsUsingProb2Predict =
- model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
- resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
- case (pred1, pred2) => assert(pred1 === pred2)
- }
-
- // force it to use predict
- model.setRawPredictionCol("").setProbabilityCol("")
- val resultsUsingPredict =
- model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
- resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
- case (pred1, pred2) => assert(pred1 === pred2)
- }
-
ProbabilisticClassifierSuite.testPredictMethods[
- Vector, LogisticRegressionModel](model, smallBinaryDataset)
+ Vector, LogisticRegressionModel](this, model, smallBinaryDataset)
+ }
+
+ test("prediction on single instance") {
+ val blor = new LogisticRegression().setFamily("binomial")
+ val blorModel = blor.fit(smallBinaryDataset)
+ testPredictionModelSinglePrediction(blorModel, smallBinaryDataset)
+ val mlor = new LogisticRegression().setFamily("multinomial")
+ val mlorModel = mlor.fit(smallMultinomialDataset)
+ testPredictionModelSinglePrediction(mlorModel, smallMultinomialDataset)
}
test("coefficients and intercept methods") {
@@ -616,19 +560,21 @@ class LogisticRegressionSuite
LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)),
LabeledPoint(1.0, Vectors.dense(0.0, -1.0))
).toDF()
- val results = model.transform(overFlowData).select("rawPrediction", "probability").collect()
-
- // probabilities are correct when margins have to be adjusted
- val raw1 = results(0).getAs[Vector](0)
- val prob1 = results(0).getAs[Vector](1)
- assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0))
- assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps)
- // probabilities are correct when margins don't have to be adjusted
- val raw2 = results(1).getAs[Vector](0)
- val prob2 = results(1).getAs[Vector](1)
- assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0))
- assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps)
+ testTransformerByGlobalCheckFunc[(Double, Vector)](overFlowData.toDF(),
+ model, "rawPrediction", "probability") { results: Seq[Row] =>
+ // probabilities are correct when margins have to be adjusted
+ val raw1 = results(0).getAs[Vector](0)
+ val prob1 = results(0).getAs[Vector](1)
+ assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0))
+ assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps)
+
+ // probabilities are correct when margins don't have to be adjusted
+ val raw2 = results(1).getAs[Vector](0)
+ val prob2 = results(1).getAs[Vector](1)
+ assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0))
+ assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps)
+ }
}
test("MultiClassSummarizer") {
@@ -2567,10 +2513,13 @@ class LogisticRegressionSuite
val model1 = lr.fit(smallBinaryDataset)
val lr2 = new LogisticRegression().setInitialModel(model1).setMaxIter(5).setFamily("binomial")
val model2 = lr2.fit(smallBinaryDataset)
- val predictions1 = model1.transform(smallBinaryDataset).select("prediction").collect()
- val predictions2 = model2.transform(smallBinaryDataset).select("prediction").collect()
- predictions1.zip(predictions2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
- assert(p1 === p2)
+ val binaryExpected = model1.transform(smallBinaryDataset).select("prediction").collect()
+ .map(_.getDouble(0))
+ for (model <- Seq(model1, model2)) {
+ testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), model,
+ "prediction") { rows: Seq[Row] =>
+ assert(rows.map(_.getDouble(0)).toArray === binaryExpected)
+ }
}
assert(model2.summary.totalIterations === 1)
@@ -2579,10 +2528,13 @@ class LogisticRegressionSuite
val lr4 = new LogisticRegression()
.setInitialModel(model3).setMaxIter(5).setFamily("multinomial")
val model4 = lr4.fit(smallMultinomialDataset)
- val predictions3 = model3.transform(smallMultinomialDataset).select("prediction").collect()
- val predictions4 = model4.transform(smallMultinomialDataset).select("prediction").collect()
- predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) =>
- assert(p1 === p2)
+ val multinomialExpected = model3.transform(smallMultinomialDataset).select("prediction")
+ .collect().map(_.getDouble(0))
+ for (model <- Seq(model3, model4)) {
+ testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model,
+ "prediction") { rows: Seq[Row] =>
+ assert(rows.map(_.getDouble(0)).toArray === multinomialExpected)
+ }
}
assert(model4.summary.totalIterations === 1)
}
@@ -2638,8 +2590,8 @@ class LogisticRegressionSuite
LabeledPoint(4.0, Vectors.dense(2.0))).toDF()
val mlr = new LogisticRegression().setFamily("multinomial")
val model = mlr.fit(constantData)
- val results = model.transform(constantData)
- results.select("rawPrediction", "probability", "prediction").collect().foreach {
+ testTransformer[(Double, Vector)](constantData, model,
+ "rawPrediction", "probability", "prediction") {
case Row(raw: Vector, prob: Vector, pred: Double) =>
assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity)))
assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0)))
@@ -2653,8 +2605,8 @@ class LogisticRegressionSuite
LabeledPoint(0.0, Vectors.dense(1.0)),
LabeledPoint(0.0, Vectors.dense(2.0))).toDF()
val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData)
- val resultsZero = modelZeroLabel.transform(constantZeroData)
- resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach {
+ testTransformer[(Double, Vector)](constantZeroData, modelZeroLabel,
+ "rawPrediction", "probability", "prediction") {
case Row(raw: Vector, prob: Vector, pred: Double) =>
assert(prob === Vectors.dense(Array(1.0)))
assert(pred === 0.0)
@@ -2666,8 +2618,8 @@ class LogisticRegressionSuite
val constantDataWithMetadata = constantData
.select(constantData("label").as("label", labelMeta), constantData("features"))
val modelWithMetadata = mlr.setFitIntercept(true).fit(constantDataWithMetadata)
- val resultsWithMetadata = modelWithMetadata.transform(constantDataWithMetadata)
- resultsWithMetadata.select("rawPrediction", "probability", "prediction").collect().foreach {
+ testTransformer[(Double, Vector)](constantDataWithMetadata, modelWithMetadata,
+ "rawPrediction", "probability", "prediction") {
case Row(raw: Vector, prob: Vector, pred: Double) =>
assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity, 0.0)))
assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0)))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index d3141ec708560..6b5fe6e49ffea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -17,22 +17,17 @@
package org.apache.spark.ml.classification
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions._
-class MultilayerPerceptronClassifierSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -75,14 +70,24 @@ class MultilayerPerceptronClassifierSuite
.setMaxIter(100)
.setSolver("l-bfgs")
val model = trainer.fit(dataset)
- val result = model.transform(dataset)
MLTestingUtils.checkCopyAndUids(trainer, model)
- val predictionAndLabels = result.select("prediction", "label").collect()
- predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
- assert(p == l)
+ testTransformer[(Vector, Double)](dataset.toDF(), model, "prediction", "label") {
+ case Row(p: Double, l: Double) => assert(p == l)
}
}
+ test("prediction on single instance") {
+ val layers = Array[Int](2, 5, 2)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(123L)
+ .setMaxIter(100)
+ .setSolver("l-bfgs")
+ val model = trainer.fit(dataset)
+ testPredictionModelSinglePrediction(model, dataset)
+ }
+
test("Predicted class probabilities: calibration on toy dataset") {
val layers = Array[Int](4, 5, 2)
@@ -99,13 +104,12 @@ class MultilayerPerceptronClassifierSuite
.setMaxIter(100)
.setSolver("l-bfgs")
val model = trainer.fit(strongDataset)
- val result = model.transform(strongDataset)
- result.select("probability", "expectedProbability").collect().foreach {
- case Row(p: Vector, e: Vector) =>
- assert(p ~== e absTol 1e-3)
+ testTransformer[(Vector, Double, Vector)](strongDataset.toDF(), model,
+ "probability", "expectedProbability") {
+ case Row(p: Vector, e: Vector) => assert(p ~== e absTol 1e-3)
}
ProbabilisticClassifierSuite.testPredictMethods[
- Vector, MultilayerPerceptronClassificationModel](model, strongDataset)
+ Vector, MultilayerPerceptronClassificationModel](this, model, strongDataset)
}
test("test model probability") {
@@ -118,11 +122,10 @@ class MultilayerPerceptronClassifierSuite
.setSolver("l-bfgs")
val model = trainer.fit(dataset)
model.setProbabilityCol("probability")
- val result = model.transform(dataset)
- val features2prob = udf { features: Vector => model.mlpModel.predict(features) }
- result.select(features2prob(col("features")), col("probability")).collect().foreach {
- case Row(p1: Vector, p2: Vector) =>
- assert(p1 ~== p2 absTol 1e-3)
+ testTransformer[(Vector, Double)](dataset.toDF(), model, "features", "probability") {
+ case Row(features: Vector, prob: Vector) =>
+ val prob2 = model.mlpModel.predict(features)
+ assert(prob ~== prob2 absTol 1e-3)
}
}
@@ -175,9 +178,6 @@ class MultilayerPerceptronClassifierSuite
val model = trainer.fit(dataFrame)
val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size
assert(model.numFeatures === numFeatures)
- val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label").rdd.map {
- case Row(p: Double, l: Double) => (p, l)
- }
// train multinomial logistic regression
val lr = new LogisticRegressionWithLBFGS()
.setIntercept(true)
@@ -189,8 +189,12 @@ class MultilayerPerceptronClassifierSuite
lrModel.predict(data.rdd.map(p => OldVectors.fromML(p.features))).zip(data.rdd.map(_.label))
// MLP's predictions should not differ a lot from LR's.
val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels)
- val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
- assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100)
+ testTransformerByGlobalCheckFunc[(Double, Vector)](dataFrame, model, "prediction", "label") {
+ rows: Seq[Row] =>
+ val mlpPredictionAndLabels = rows.map(x => (x.getDouble(0), x.getDouble(1)))
+ val mlpMetrics = new MulticlassMetrics(sc.makeRDD(mlpPredictionAndLabels))
+ assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100)
+ }
}
test("read/write: MultilayerPerceptronClassifier") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 0d3adf993383f..5f9ab98a2c3ce 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -28,12 +28,11 @@ import org.apache.spark.ml.classification.NaiveBayesSuite._
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row}
-class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -56,13 +55,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
bernoulliDataset = generateNaiveBayesInput(pi, theta, 100, seed, "bernoulli").toDF()
}
- def validatePrediction(predictionAndLabels: DataFrame): Unit = {
- val numOfErrorPredictions = predictionAndLabels.collect().count {
+ def validatePrediction(predictionAndLabels: Seq[Row]): Unit = {
+ val numOfErrorPredictions = predictionAndLabels.filter {
case Row(prediction: Double, label: Double) =>
prediction != label
- }
+ }.length
// At least 80% of the predictions should be correct.
- assert(numOfErrorPredictions < predictionAndLabels.count() / 5)
+ assert(numOfErrorPredictions < predictionAndLabels.length / 5)
}
def validateModelFit(
@@ -92,10 +91,10 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
def validateProbabilities(
- featureAndProbabilities: DataFrame,
+ featureAndProbabilities: Seq[Row],
model: NaiveBayesModel,
modelType: String): Unit = {
- featureAndProbabilities.collect().foreach {
+ featureAndProbabilities.foreach {
case Row(features: Vector, probability: Vector) =>
assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
val expected = modelType match {
@@ -154,15 +153,40 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val validationDataset =
generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF()
- val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
- validatePrediction(predictionAndLabels)
+ testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model,
+ "prediction", "label") { predictionAndLabels: Seq[Row] =>
+ validatePrediction(predictionAndLabels)
+ }
- val featureAndProbabilities = model.transform(validationDataset)
- .select("features", "probability")
- validateProbabilities(featureAndProbabilities, model, "multinomial")
+ testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model,
+ "features", "probability") { featureAndProbabilities: Seq[Row] =>
+ validateProbabilities(featureAndProbabilities, model, "multinomial")
+ }
ProbabilisticClassifierSuite.testPredictMethods[
- Vector, NaiveBayesModel](model, testDataset)
+ Vector, NaiveBayesModel](this, model, testDataset)
+ }
+
+ test("prediction on single instance") {
+ val nPoints = 1000
+ val piArray = Array(0.5, 0.1, 0.4).map(math.log)
+ val thetaArray = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0
+ Array(0.10, 0.70, 0.10, 0.10), // label 1
+ Array(0.10, 0.10, 0.70, 0.10) // label 2
+ ).map(_.map(math.log))
+ val pi = Vectors.dense(piArray)
+ val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
+
+ val trainDataset =
+ generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF()
+ val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
+ val model = nb.fit(trainDataset)
+
+ val validationDataset =
+ generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF()
+
+ testPredictionModelSinglePrediction(model, validationDataset)
}
test("Naive Bayes with weighted samples") {
@@ -210,15 +234,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val validationDataset =
generateNaiveBayesInput(piArray, thetaArray, nPoints, 20, "bernoulli").toDF()
- val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
- validatePrediction(predictionAndLabels)
+ testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model,
+ "prediction", "label") { predictionAndLabels: Seq[Row] =>
+ validatePrediction(predictionAndLabels)
+ }
- val featureAndProbabilities = model.transform(validationDataset)
- .select("features", "probability")
- validateProbabilities(featureAndProbabilities, model, "bernoulli")
+ testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model,
+ "features", "probability") { featureAndProbabilities: Seq[Row] =>
+ validateProbabilities(featureAndProbabilities, model, "bernoulli")
+ }
ProbabilisticClassifierSuite.testPredictMethods[
- Vector, NaiveBayesModel](model, testDataset)
+ Vector, NaiveBayesModel](this, model, testDataset)
}
test("detect negative values") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 25bad59b9c9cf..2c3417c7e4028 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -17,26 +17,24 @@
package org.apache.spark.ml.classification
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.StringIndexer
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.Metadata
-class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -74,21 +72,18 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
.setClassifier(new LogisticRegression)
assert(ova.getLabelCol === "label")
assert(ova.getPredictionCol === "prediction")
+ assert(ova.getRawPredictionCol === "rawPrediction")
val ovaModel = ova.fit(dataset)
MLTestingUtils.checkCopyAndUids(ova, ovaModel)
- assert(ovaModel.models.length === numClasses)
+ assert(ovaModel.numClasses === numClasses)
val transformedDataset = ovaModel.transform(dataset)
// check for label metadata in prediction col
val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3))
- val ovaResults = transformedDataset.select("prediction", "label").rdd.map {
- row => (row.getDouble(0), row.getDouble(1))
- }
-
val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses)
lr.optimizer.setRegParam(0.1).setNumIterations(100)
@@ -97,8 +92,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
// determine the #confusion matrix in each class.
// bound how much error we allow compared to multinomial logistic regression.
val expectedMetrics = new MulticlassMetrics(results)
- val ovaMetrics = new MulticlassMetrics(ovaResults)
- assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400)
+
+ testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), ovaModel,
+ "prediction", "label") { rows =>
+ val ovaResults = rows.map { row => (row.getDouble(0), row.getDouble(1)) }
+ val ovaMetrics = new MulticlassMetrics(sc.makeRDD(ovaResults))
+ assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400)
+ }
}
test("one-vs-rest: tuning parallelism does not change output") {
@@ -180,6 +180,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
ovaModel.setFeaturesCol("fea")
ovaModel.setPredictionCol("pred")
+ ovaModel.setRawPredictionCol("")
val transformedDataset = ovaModel.transform(dataset2)
val outputFields = transformedDataset.schema.fieldNames.toSet
assert(outputFields === Set("y", "fea", "pred"))
@@ -191,7 +192,8 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val ovr = new OneVsRest()
.setClassifier(logReg)
val output = ovr.fit(dataset).transform(dataset)
- assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+ assert(output.schema.fieldNames.toSet
+ === Set("label", "features", "prediction", "rawPrediction"))
}
test("SPARK-21306: OneVsRest should support setWeightCol") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index d649ceac949c4..1c8c9829f18d1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.MLTest
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.sql.{Dataset, Row}
@@ -122,13 +123,15 @@ object ProbabilisticClassifierSuite {
def testPredictMethods[
FeaturesType,
M <: ProbabilisticClassificationModel[FeaturesType, M]](
- model: M, testData: Dataset[_]): Unit = {
+ mlTest: MLTest, model: M, testData: Dataset[_]): Unit = {
val allColModel = model.copy(ParamMap.empty)
.setRawPredictionCol("rawPredictionAll")
.setProbabilityCol("probabilityAll")
.setPredictionCol("predictionAll")
- val allColResult = allColModel.transform(testData)
+
+ val allColResult = allColModel.transform(testData.select(allColModel.getFeaturesCol))
+ .select(allColModel.getFeaturesCol, "rawPredictionAll", "probabilityAll", "predictionAll")
for (rawPredictionCol <- Seq("", "rawPredictionSingle")) {
for (probabilityCol <- Seq("", "probabilitySingle")) {
@@ -138,22 +141,14 @@ object ProbabilisticClassifierSuite {
.setProbabilityCol(probabilityCol)
.setPredictionCol(predictionCol)
- val result = newModel.transform(allColResult)
-
- import org.apache.spark.sql.functions._
-
- val resultRawPredictionCol =
- if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol)
- val resultProbabilityCol =
- if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol)
- val resultPredictionCol =
- if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol)
+ import allColResult.sparkSession.implicits._
- result.select(
- resultRawPredictionCol, col("rawPredictionAll"),
- resultProbabilityCol, col("probabilityAll"),
- resultPredictionCol, col("predictionAll")
- ).collect().foreach {
+ mlTest.testTransformer[(Vector, Vector, Vector, Double)](allColResult, newModel,
+ if (rawPredictionCol.isEmpty) "rawPredictionAll" else rawPredictionCol,
+ "rawPredictionAll",
+ if (probabilityCol.isEmpty) "probabilityAll" else probabilityCol, "probabilityAll",
+ if (predictionCol.isEmpty) "predictionAll" else predictionCol, "predictionAll"
+ ) {
case Row(
rawPredictionSingle: Vector, rawPredictionAll: Vector,
probabilitySingle: Vector, probabilityAll: Vector,
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 2cca2e6c04698..3062aa9f3d274 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,13 +21,12 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.ClassificationLeafNode
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
@@ -35,8 +34,7 @@ import org.apache.spark.sql.{DataFrame, Row}
/**
* Test suite for [[RandomForestClassifier]].
*/
-class RandomForestClassifierSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
import RandomForestClassifierSuite.compareAPIs
import testImplicits._
@@ -73,7 +71,8 @@ class RandomForestClassifierSuite
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
- Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
+ Array(new DecisionTreeClassificationModel("dtc",
+ new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
ParamsSuite.checkParams(model)
}
@@ -143,11 +142,8 @@ class RandomForestClassifierSuite
MLTestingUtils.checkCopyAndUids(rf, model)
- val predictions = model.transform(df)
- .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)
- .collect()
-
- predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+ testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction",
+ "probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
assert(pred === rawPred.argmax,
s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
val sum = rawPred.toArray.sum
@@ -155,8 +151,25 @@ class RandomForestClassifierSuite
"probability prediction mismatch")
assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
}
+
ProbabilisticClassifierSuite.testPredictMethods[
- Vector, RandomForestClassificationModel](model, df)
+ Vector, RandomForestClassificationModel](this, model, df)
+ }
+
+ test("prediction on single instance") {
+ val rdd = orderedLabeledPoints5_20
+ val rf = new RandomForestClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(3)
+ .setNumTrees(3)
+ .setSeed(123)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val numClasses = 2
+
+ val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val model = rf.fit(df)
+
+ testPredictionModelSinglePrediction(model, df)
}
test("Fitting without numClasses in metadata") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index fa7471fa2d658..81842afbddbbb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -17,14 +17,20 @@
package org.apache.spark.ml.clustering
-import org.apache.spark.SparkFunSuite
+import scala.language.existentials
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.clustering.DistanceMeasure
import org.apache.spark.sql.Dataset
-class BisectingKMeansSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest {
+
+ import testImplicits._
final val k = 5
@transient var dataset: Dataset[_] = _
@@ -63,10 +69,13 @@ class BisectingKMeansSuite
// Verify fit does not fail on very sparse data
val model = bkm.fit(sparseDataset)
- val result = model.transform(sparseDataset)
- val numClusters = result.select("prediction").distinct().collect().length
- // Verify we hit the edge case
- assert(numClusters < k && numClusters > 1)
+
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](sparseDataset.toDF(), model, "prediction") {
+ rows =>
+ val numClusters = rows.distinct.length
+ // Verify we hit the edge case
+ assert(numClusters < k && numClusters > 1)
+ }
}
test("setter/getter") {
@@ -99,19 +108,16 @@ class BisectingKMeansSuite
val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = bkm.fit(dataset)
assert(model.clusterCenters.length === k)
-
- val transformed = model.transform(dataset)
- val expectedColumns = Array("features", predictionColName)
- expectedColumns.foreach { column =>
- assert(transformed.columns.contains(column))
- }
- val clusters =
- transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet
- assert(clusters.size === k)
- assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model,
+ "features", predictionColName) { rows =>
+ val clusters = rows.map(_.getAs[Int](predictionColName)).toSet
+ assert(clusters.size === k)
+ assert(clusters === Set(0, 1, 2, 3, 4))
+ }
+
// Check validity of model summary
val numRows = dataset.count()
assert(model.hasSummary)
@@ -140,6 +146,62 @@ class BisectingKMeansSuite
testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
BisectingKMeansSuite.allParamSettings, checkModelData)
}
+
+ test("BisectingKMeans with cosine distance is not supported for 0-length vectors") {
+ val model = new BisectingKMeans().setK(2).setDistanceMeasure(DistanceMeasure.COSINE).setSeed(1)
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
+ Vectors.dense(0.0, 0.0),
+ Vectors.dense(10.0, 10.0),
+ Vectors.dense(1.0, 0.5)
+ )).map(v => TestRow(v)))
+ val e = intercept[SparkException](model.fit(df))
+ assert(e.getCause.isInstanceOf[AssertionError])
+ assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
+ }
+
+ test("BisectingKMeans with cosine distance") {
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
+ Vectors.dense(1.0, 1.0),
+ Vectors.dense(10.0, 10.0),
+ Vectors.dense(1.0, 0.5),
+ Vectors.dense(10.0, 4.4),
+ Vectors.dense(-1.0, 1.0),
+ Vectors.dense(-100.0, 90.0)
+ )).map(v => TestRow(v)))
+ val model = new BisectingKMeans()
+ .setK(3)
+ .setDistanceMeasure(DistanceMeasure.COSINE)
+ .setSeed(1)
+ .fit(df)
+ val predictionDf = model.transform(df)
+ assert(predictionDf.select("prediction").distinct().count() == 3)
+ val predictionsMap = predictionDf.collect().map(row =>
+ row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
+ assert(predictionsMap(Vectors.dense(1.0, 1.0)) ==
+ predictionsMap(Vectors.dense(10.0, 10.0)))
+ assert(predictionsMap(Vectors.dense(1.0, 0.5)) ==
+ predictionsMap(Vectors.dense(10.0, 4.4)))
+ assert(predictionsMap(Vectors.dense(-1.0, 1.0)) ==
+ predictionsMap(Vectors.dense(-100.0, 90.0)))
+
+ assert(model.clusterCenters.forall(v => Vectors.norm(v, 2) ~== 1.0 absTol 1e-6))
+ }
+
+ test("BisectingKMeans with Array input") {
+ def trainAndComputeCost(dataset: Dataset[_]): Double = {
+ val model = new BisectingKMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset)
+ model.computeCost(dataset)
+ }
+
+ val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset)
+ val trueCost = trainAndComputeCost(newDataset)
+ val doubleArrayCost = trainAndComputeCost(newDatasetD)
+ val floatArrayCost = trainAndComputeCost(newDatasetF)
+
+ // Comparing the costs is a sufficient sanity check
+ assert(trueCost ~== doubleArrayCost absTol 1e-6)
+ assert(trueCost ~== floatArrayCost absTol 1e-6)
+ }
}
object BisectingKMeansSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 08b800b7e4183..0b91f502f615b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -17,21 +17,21 @@
package org.apache.spark.ml.clustering
+import scala.language.existentials
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.stat.distribution.MultivariateGaussian
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
-class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
- with DefaultReadWriteTest {
+class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest {
- import testImplicits._
import GaussianMixtureSuite._
+ import testImplicits._
final val k = 5
private val seed = 538009335
@@ -118,15 +118,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.weights.length === k)
assert(model.gaussians.length === k)
- val transformed = model.transform(dataset)
- val expectedColumns = Array("features", predictionColName, probabilityColName)
- expectedColumns.foreach { column =>
- assert(transformed.columns.contains(column))
- }
-
// Check prediction matches the highest probability, and probabilities sum to one.
- transformed.select(predictionColName, probabilityColName).collect().foreach {
- case Row(pred: Int, prob: Vector) =>
+ testTransformer[Tuple1[Vector]](dataset.toDF(), model,
+ "features", predictionColName, probabilityColName) {
+ case Row(_, pred: Int, prob: Vector) =>
val probArray = prob.toArray
val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2
assert(pred === predFromProb)
@@ -256,6 +251,22 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues)
assert(symmetricMatrix === expectedMatrix)
}
+
+ test("GaussianMixture with Array input") {
+ def trainAndComputeLogLikelihood(dataset: Dataset[_]): Double = {
+ val model = new GaussianMixture().setK(k).setMaxIter(1).setSeed(1).fit(dataset)
+ model.summary.logLikelihood
+ }
+
+ val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset)
+ val trueLikelihood = trainAndComputeLogLikelihood(newDataset)
+ val doubleLikelihood = trainAndComputeLogLikelihood(newDatasetD)
+ val floatLikelihood = trainAndComputeLogLikelihood(newDatasetF)
+
+ // Comparing the log-likelihoods is a sufficient sanity check
+ assert(trueLikelihood ~== doubleLikelihood absTol 1e-6)
+ assert(trueLikelihood ~== floatLikelihood absTol 1e-6)
+ }
}
object GaussianMixtureSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index e4506f23feb31..2569e7a432ca4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -17,19 +17,26 @@
package org.apache.spark.ml.clustering
+import scala.language.existentials
import scala.util.Random
-import org.apache.spark.SparkFunSuite
+import org.dmg.pmml.{ClusteringModel, PMML}
+
+import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils, PMMLReadWriteTest}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans,
+ KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
private[clustering] case class TestRow(features: Vector)
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest {
+
+ import testImplicits._
final val k = 5
@transient var dataset: Dataset[_] = _
@@ -103,15 +110,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
val model = kmeans.fit(dataset)
assert(model.clusterCenters.length === k)
- val transformed = model.transform(dataset)
- val expectedColumns = Array("features", predictionColName)
- expectedColumns.foreach { column =>
- assert(transformed.columns.contains(column))
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model,
+ "features", predictionColName) { rows =>
+ val clusters = rows.map(_.getAs[Int](predictionColName)).toSet
+ assert(clusters.size === k)
+ assert(clusters === Set(0, 1, 2, 3, 4))
}
- val clusters =
- transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet
- assert(clusters.size === k)
- assert(clusters === Set(0, 1, 2, 3, 4))
+
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
@@ -143,9 +148,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName)
val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName))
- Seq(featuresColName, predictionColName).foreach { column =>
- assert(transformed.columns.contains(column))
- }
+ assert(transformed.schema.fieldNames.toSet === Set(featuresColName, predictionColName))
assert(model.getFeaturesCol == featuresColName)
assert(model.getPredictionCol == predictionColName)
}
@@ -179,8 +182,38 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(predictionsMap(Vectors.dense(-1.0, 1.0)) ==
predictionsMap(Vectors.dense(-100.0, 90.0)))
+ assert(model.clusterCenters.forall(v => Vectors.norm(v, 2) ~== 1.0 absTol 1e-6))
+ }
+
+ test("KMeans with cosine distance is not supported for 0-length vectors") {
+ val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2)
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
+ Vectors.dense(0.0, 0.0),
+ Vectors.dense(10.0, 10.0),
+ Vectors.dense(1.0, 0.5)
+ )).map(v => TestRow(v)))
+ val e = intercept[SparkException](model.fit(df))
+ assert(e.getCause.isInstanceOf[AssertionError])
+ assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
+ }
+
+ test("KMean with Array input") {
+ def trainAndComputeCost(dataset: Dataset[_]): Double = {
+ val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset)
+ model.computeCost(dataset)
+ }
+
+ val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset)
+ val trueCost = trainAndComputeCost(newDataset)
+ val doubleArrayCost = trainAndComputeCost(newDatasetD)
+ val floatArrayCost = trainAndComputeCost(newDatasetF)
+
+ // Comparing the costs is a sufficient sanity check
+ assert(trueCost ~== doubleArrayCost absTol 1e-6)
+ assert(trueCost ~== floatArrayCost absTol 1e-6)
}
+
test("read/write") {
def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
assert(model.clusterCenters === model2.clusterCenters)
@@ -189,6 +222,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
KMeansSuite.allParamSettings, checkModelData)
}
+
+ test("pmml export") {
+ val clusterCenters = Array(
+ MLlibVectors.dense(1.0, 2.0, 6.0),
+ MLlibVectors.dense(1.0, 3.0, 0.0),
+ MLlibVectors.dense(1.0, 4.0, 6.0))
+ val oldKmeansModel = new MLlibKMeansModel(clusterCenters)
+ val kmeansModel = new KMeansModel("", oldKmeansModel)
+ def checkModel(pmml: PMML): Unit = {
+ // Check that the header description is what we expect
+ assert(pmml.getHeader.getDescription === "k-means clustering")
+ // Check that the number of fields matches the single vector size
+ assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size)
+ // This verifies that a model is attached to the PMML object and that it is a clustering
+ // model. It also verifies that the PMML model has the same number of clusters as the Spark
+ // model.
+ val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel]
+ assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length)
+ }
+ testPMMLWrite(sc, kmeansModel, checkModel)
+ }
}
object KMeansSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index e73bbc18d76bd..db92132d18b7b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -17,16 +17,15 @@
package org.apache.spark.ml.clustering
+import scala.language.existentials
+
import org.apache.hadoop.fs.Path
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql._
-
object LDASuite {
def generateLDAData(
spark: SparkSession,
@@ -35,9 +34,8 @@ object LDASuite {
vocabSize: Int): DataFrame = {
val avgWC = 1 // average instances of each word in a doc
val sc = spark.sparkContext
- val rng = new java.util.Random()
- rng.setSeed(1)
val rdd = sc.parallelize(1 to rows).map { i =>
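+ // Seed the RNG with the row index so each generated document is deterministic.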
+ val rng = new java.util.Random(i)
Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble))
}.map(v => new TestRow(v))
spark.createDataFrame(rdd)
@@ -60,7 +58,7 @@ object LDASuite {
}
-class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LDASuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -185,16 +183,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
assert(model.topicsMatrix.numCols === k)
assert(!model.isDistributed)
- // transform()
- val transformed = model.transform(dataset)
- val expectedColumns = Array("features", lda.getTopicDistributionCol)
- expectedColumns.foreach { column =>
- assert(transformed.columns.contains(column))
- }
- transformed.select(lda.getTopicDistributionCol).collect().foreach { r =>
- val topicDistribution = r.getAs[Vector](0)
- assert(topicDistribution.size === k)
- assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
+ testTransformer[Tuple1[Vector]](dataset.toDF(), model,
+ "features", lda.getTopicDistributionCol) {
+ case Row(_, topicDistribution: Vector) =>
+ assert(topicDistribution.size === k)
+ assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
}
// logLikelihood, logPerplexity
@@ -252,6 +245,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val lda = new LDA()
testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
LDASuite.allParamSettings, checkModelData)
+
+ // Make sure the result is deterministic after saving and loading the model
+ val model = lda.fit(dataset)
+ val model2 = testDefaultReadWrite(model)
+ assert(model.logLikelihood(dataset) ~== model2.logLikelihood(dataset) absTol 1e-6)
+ assert(model.logPerplexity(dataset) ~== model2.logPerplexity(dataset) absTol 1e-6)
}
test("read/write DistributedLDAModel") {
@@ -323,4 +322,21 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
assert(model.getOptimizer === optimizer)
}
}
+
+ test("LDA with Array input") {
+ def trainAndLogLikelihoodAndPerplexity(dataset: Dataset[_]): (Double, Double) = {
+ val model = new LDA().setK(k).setOptimizer("online").setMaxIter(1).setSeed(1).fit(dataset)
+ (model.logLikelihood(dataset), model.logPerplexity(dataset))
+ }
+
+ val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset)
+ val (ll, lp) = trainAndLogLikelihoodAndPerplexity(newDataset)
+ val (llD, lpD) = trainAndLogLikelihoodAndPerplexity(newDatasetD)
+ val (llF, lpF) = trainAndLogLikelihoodAndPerplexity(newDatasetF)
+ // TODO: need to compare the results once we fix the seed issue for LDA (SPARK-22210)
+ assert(llD <= 0.0 && llD != Double.NegativeInfinity)
+ assert(llF <= 0.0 && llF != Double.NegativeInfinity)
+ assert(lpD >= 0.0 && lpD != Double.NegativeInfinity)
+ assert(lpF >= 0.0 && lpF != Double.NegativeInfinity)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
new file mode 100644
index 0000000000000..b7072728d48f0
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types._
+
+
+class PowerIterationClusteringSuite extends SparkFunSuite
+ with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ import testImplicits._
+
+ @transient var data: Dataset[_] = _
+ final val r1 = 1.0
+ final val n1 = 10
+ final val r2 = 4.0
+ final val n2 = 40
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ data = PowerIterationClusteringSuite.generatePICData(spark, r1, r2, n1, n2)
+ }
+
+ test("default parameters") {
+ val pic = new PowerIterationClustering()
+
+ assert(pic.getK === 2)
+ assert(pic.getMaxIter === 20)
+ assert(pic.getInitMode === "random")
+ assert(pic.getSrcCol === "src")
+ assert(pic.getDstCol === "dst")
+ assert(!pic.isDefined(pic.weightCol))
+ }
+
+ test("parameter validation") {
+ intercept[IllegalArgumentException] {
+ new PowerIterationClustering().setK(1)
+ }
+ intercept[IllegalArgumentException] {
+ new PowerIterationClustering().setInitMode("no_such_a_mode")
+ }
+ intercept[IllegalArgumentException] {
+ new PowerIterationClustering().setSrcCol("")
+ }
+ intercept[IllegalArgumentException] {
+ new PowerIterationClustering().setDstCol("")
+ }
+ }
+
+ test("power iteration clustering") {
+ val n = n1 + n2
+
+ val assignments = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIter(40)
+ .setWeightCol("weight")
+ .assignClusters(data)
+ val localAssignments = assignments
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+ val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++
+ (n1 until n).map(x => (x, 0)).toSet
+ assert(localAssignments === expectedResult)
+
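+ // Run again with degree-based initialization; the assignments should match the same result.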
+ val assignments2 = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIter(10)
+ .setInitMode("degree")
+ .setWeightCol("weight")
+ .assignClusters(data)
+ val localAssignments2 = assignments2
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+ assert(localAssignments2 === expectedResult)
+ }
+
+ test("supported input types") {
+ val pic = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIter(1)
+ .setWeightCol("weight")
+
+ def runTest(srcType: DataType, dstType: DataType, weightType: DataType): Unit = {
+ val typedData = data.select(
+ col("src").cast(srcType).alias("src"),
+ col("dst").cast(dstType).alias("dst"),
+ col("weight").cast(weightType).alias("weight")
+ )
+ pic.assignClusters(typedData).collect()
+ }
+
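+ // Ids may be IntegerType or LongType, and weights FloatType or DoubleType.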
+ for (srcType <- Seq(IntegerType, LongType)) {
+ runTest(srcType, LongType, DoubleType)
+ }
+ for (dstType <- Seq(IntegerType, LongType)) {
+ runTest(LongType, dstType, DoubleType)
+ }
+ for (weightType <- Seq(FloatType, DoubleType)) {
+ runTest(LongType, LongType, weightType)
+ }
+ }
+
+ test("invalid input: negative similarity") {
+ val pic = new PowerIterationClustering()
+ .setMaxIter(1)
+ .setWeightCol("weight")
+ val badData = spark.createDataFrame(Seq(
+ (0, 1, -1.0),
+ (1, 0, -1.0)
+ )).toDF("src", "dst", "weight")
+ val msg = intercept[SparkException] {
+ pic.assignClusters(badData)
+ }.getCause.getMessage
+ assert(msg.contains("Similarity must be nonnegative"))
+ }
+
+ test("test default weight") {
+ val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst)
+
+ val assignments = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIter(40)
+ .assignClusters(dataWithoutWeight)
+ val localAssignments = assignments
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+
+ val dataWithWeightOne = dataWithoutWeight.withColumn("weight", lit(1.0))
+
+ val assignments2 = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIter(40)
+ .assignClusters(dataWithWeightOne)
+ val localAssignments2 = assignments2
+ .select('id, 'cluster)
+ .as[(Long, Int)].collect().toSet
+
+ assert(localAssignments === localAssignments2)
+ }
+
+ test("read/write") {
+ val t = new PowerIterationClustering()
+ .setK(4)
+ .setMaxIter(100)
+ .setInitMode("degree")
+ .setSrcCol("src1")
+ .setDstCol("dst1")
+ .setWeightCol("weight")
+ testDefaultReadWrite(t)
+ }
+}
+
+object PowerIterationClusteringSuite {
+
+ /** Generates a circle of points. */
+ private def genCircle(r: Double, n: Int): Array[(Double, Double)] = {
+ Array.tabulate(n) { i =>
+ val theta = 2.0 * math.Pi * i / n
+ (r * math.cos(theta), r * math.sin(theta))
+ }
+ }
+
+ /** Computes Gaussian similarity. */
+ private def sim(x: (Double, Double), y: (Double, Double)): Double = {
+ val dist2 = (x._1 - y._1) * (x._1 - y._1) + (x._2 - y._2) * (x._2 - y._2)
+ math.exp(-dist2 / 2.0)
+ }
+
+ def generatePICData(
+ spark: SparkSession,
+ r1: Double,
+ r2: Double,
+ n1: Int,
+ n2: Int): DataFrame = {
+ // Generate two circles following the example in the PIC paper.
+ val n = n1 + n2
+ val points = genCircle(r1, n1) ++ genCircle(r2, n2)
+
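+ // Pairwise Gaussian similarities over all pairs (i, j) with i > j form the affinity graph.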
+ val rows = (for (i <- 1 until n) yield {
+ for (j <- 0 until i) yield {
+ (i.toLong, j.toLong, sim(points(i), points(j)))
+ }
+ }).flatMap(_.iterator)
+
+ spark.createDataFrame(rows).toDF("src", "dst", "weight")
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
index 677ce49a903ab..2c175ff68e0b8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.ml.evaluation
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.TestingUtils._
@@ -66,16 +68,57 @@ class ClusteringEvaluatorSuite
assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5)
}
- test("number of clusters must be greater than one") {
- val singleClusterDataset = irisDataset.where($"label" === 0.0)
+ /*
+ Use the following Python code to load the data and evaluate it with the scikit-learn package.
+
+ from sklearn import datasets
+ from sklearn.metrics import silhouette_score
+ iris = datasets.load_iris()
+ round(silhouette_score(iris.data, iris.target, metric='cosine'), 10)
+
+ 0.7222369298
+ */
+ test("cosine Silhouette") {
val evaluator = new ClusteringEvaluator()
.setFeaturesCol("features")
.setPredictionCol("label")
+ .setDistanceMeasure("cosine")
+
+ assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5)
+ }
- val e = intercept[AssertionError]{
- evaluator.evaluate(singleClusterDataset)
+ test("number of clusters must be greater than one") {
+ val singleClusterDataset = irisDataset.where($"label" === 0.0)
+ Seq("squaredEuclidean", "cosine").foreach { distanceMeasure =>
+ val evaluator = new ClusteringEvaluator()
+ .setFeaturesCol("features")
+ .setPredictionCol("label")
+ .setDistanceMeasure(distanceMeasure)
+
+ val e = intercept[AssertionError] {
+ evaluator.evaluate(singleClusterDataset)
+ }
+ assert(e.getMessage.contains("Number of clusters must be greater than one"))
}
- assert(e.getMessage.contains("Number of clusters must be greater than one"))
}
+ test("SPARK-23568: we should use metadata to determine features number") {
+ val attributesNum = irisDataset.select("features").rdd.first().getAs[Vector](0).size
+ val attrGroup = new AttributeGroup("features", attributesNum)
+ val df = irisDataset.select($"features".as("features", attrGroup.toMetadata()), $"label")
+ require(AttributeGroup.fromStructField(df.schema("features"))
+ .numAttributes.isDefined, "numAttributes metadata should be defined")
+ val evaluator = new ClusteringEvaluator()
+ .setFeaturesCol("features")
+ .setPredictionCol("label")
+
+ // With the proper metadata the result is computed correctly
+ assert(evaluator.evaluate(df) ~== 0.6564679231 relTol 1e-5)
+
+ val wrongAttrGroup = new AttributeGroup("features", attributesNum + 1)
+ val dfWrong = irisDataset.select($"features".as("features", wrongAttrGroup.toMetadata()),
+ $"label")
+ // With wrong metadata the evaluator throws an exception
+ intercept[SparkException](evaluator.evaluate(dfWrong))
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 4455d35210878..05d4a6ee2dabf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -17,14 +17,12 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.{DataFrame, Row}
-class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class BinarizerSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -47,7 +45,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
.setInputCol("feature")
.setOutputCol("binarized_feature")
- binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+ testTransformer[(Double, Double)](dataFrame, binarizer, "binarized_feature", "expected") {
case Row(x: Double, y: Double) =>
assert(x === y, "The feature value is not correct after binarization.")
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index 7175c721bff36..9b823259b1deb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -20,16 +20,15 @@ package org.apache.spark.ml.feature
import breeze.numerics.{cos, sin}
import breeze.numerics.constants.Pi
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{Dataset, Row}
-class BucketedRandomProjectionLSHSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest {
+
+ import testImplicits._
@transient var dataset: Dataset[_] = _
@@ -49,6 +48,14 @@ class BucketedRandomProjectionLSHSuite
ParamsSuite.checkParams(model)
}
+ test("setters") {
+ val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0)))
+ .setInputCol("testkeys")
+ .setOutputCol("testvalues")
+ assert(model.getInputCol === "testkeys")
+ assert(model.getOutputCol === "testvalues")
+ }
+
test("BucketedRandomProjectionLSH: default params") {
val brp = new BucketedRandomProjectionLSH
assert(brp.getNumHashTables === 1.0)
@@ -98,6 +105,21 @@ class BucketedRandomProjectionLSHSuite
MLTestingUtils.checkCopyAndUids(brp, brpModel)
}
+ test("BucketedRandomProjectionLSH: streaming transform") {
+ val brp = new BucketedRandomProjectionLSH()
+ .setNumHashTables(2)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setBucketLength(1.0)
+ .setSeed(12345)
+ val brpModel = brp.fit(dataset)
+
+ testTransformer[Tuple1[Vector]](dataset.toDF(), brpModel, "values") {
+ case Row(values: Seq[_]) =>
+ assert(values.length === brp.getNumHashTables)
+ }
+ }
+
test("BucketedRandomProjectionLSH: test of LSH property") {
// Project from 2 dimensional Euclidean Space to 1 dimensions
val brp = new BucketedRandomProjectionLSH()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 7403680ae3fdc..9ea15e1918532 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -23,14 +23,13 @@ import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class BucketizerSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -50,7 +49,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCol("result")
.setSplits(splits)
- bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+ testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") {
case Row(x: Double, y: Double) =>
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
@@ -84,7 +83,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCol("result")
.setSplits(splits)
- bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+ testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") {
case Row(x: Double, y: Double) =>
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
@@ -103,7 +102,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setSplits(splits)
bucketizer.setHandleInvalid("keep")
- bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+ testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") {
case Row(x: Double, y: Double) =>
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
@@ -172,7 +171,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setSplits(Array(0.1, 0.8, 0.9))
- testDefaultReadWrite(t)
+
+ val bucketizer = testDefaultReadWrite(t)
+ val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2")
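+ // Make sure the loaded Bucketizer can still transform data.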
+ bucketizer.transform(data)
}
test("Bucket numeric features") {
@@ -327,7 +329,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setInputCols(Array("myInputCol"))
.setOutputCols(Array("myOutputCol"))
.setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
- testDefaultReadWrite(t)
+
+ val bucketizer = testDefaultReadWrite(t)
+ val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2")
+ bucketizer.transform(data)
+ assert(t.hasDefault(t.outputCol))
+ assert(bucketizer.hasDefault(bucketizer.outputCol))
}
test("Bucketizer in a pipeline") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index c83909c4498f2..c843df9f33e3e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -17,16 +17,15 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
-class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
- with DefaultReadWriteTest {
+class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest {
+
+ import testImplicits._
@transient var dataset: Dataset[_] = _
@@ -119,32 +118,32 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
test("Test Chi-Square selector: numTopFeatures") {
val selector = new ChiSqSelector()
.setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1)
- val model = ChiSqSelectorSuite.testSelector(selector, dataset)
+ val model = testSelector(selector, dataset)
MLTestingUtils.checkCopyAndUids(selector, model)
}
test("Test Chi-Square selector: percentile") {
val selector = new ChiSqSelector()
.setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.17)
- ChiSqSelectorSuite.testSelector(selector, dataset)
+ testSelector(selector, dataset)
}
test("Test Chi-Square selector: fpr") {
val selector = new ChiSqSelector()
.setOutputCol("filtered").setSelectorType("fpr").setFpr(0.02)
- ChiSqSelectorSuite.testSelector(selector, dataset)
+ testSelector(selector, dataset)
}
test("Test Chi-Square selector: fdr") {
val selector = new ChiSqSelector()
.setOutputCol("filtered").setSelectorType("fdr").setFdr(0.12)
- ChiSqSelectorSuite.testSelector(selector, dataset)
+ testSelector(selector, dataset)
}
test("Test Chi-Square selector: fwe") {
val selector = new ChiSqSelector()
.setOutputCol("filtered").setSelectorType("fwe").setFwe(0.12)
- ChiSqSelectorSuite.testSelector(selector, dataset)
+ testSelector(selector, dataset)
}
test("read/write") {
@@ -163,18 +162,19 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
assert(expected.selectedFeatures === actual.selectedFeatures)
}
}
-}
-object ChiSqSelectorSuite {
-
- private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = {
- val selectorModel = selector.fit(dataset)
- selectorModel.transform(dataset).select("filtered", "topFeature").collect()
- .foreach { case Row(vec1: Vector, vec2: Vector) =>
+ private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = {
+ val selectorModel = selector.fit(data)
+ testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel,
+ "filtered", "topFeature") {
+ case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
- }
+ }
selectorModel
}
+}
+
+object ChiSqSelectorSuite {
/**
* Mapping from all Params to valid settings which differ from the defaults.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index 1784c07ca23e3..61217669d9277 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -16,16 +16,13 @@
*/
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
-class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
- with DefaultReadWriteTest {
+class CountVectorizerSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -50,7 +47,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
.setInputCol("words")
.setOutputCol("features")
- cv.transform(df).select("features", "expected").collect().foreach {
+ testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}
@@ -72,7 +69,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
MLTestingUtils.checkCopyAndUids(cv, cvm)
assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
- cvm.transform(df).select("features", "expected").collect().foreach {
+ testTransformer[(Int, Seq[String], Vector)](df, cvm, "features", "expected") {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}
@@ -100,7 +97,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
.fit(df)
assert(cvModel2.vocabulary === Array("a", "b"))
- cvModel2.transform(df).select("features", "expected").collect().foreach {
+ testTransformer[(Int, Seq[String], Vector)](df, cvModel2, "features", "expected") {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}
@@ -113,7 +110,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
.fit(df)
assert(cvModel3.vocabulary === Array("a", "b"))
- cvModel3.transform(df).select("features", "expected").collect().foreach {
+ testTransformer[(Int, Seq[String], Vector)](df, cvModel3, "features", "expected") {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}
@@ -219,7 +216,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
.setInputCol("words")
.setOutputCol("features")
.setMinTF(3)
- cv.transform(df).select("features", "expected").collect().foreach {
+ testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}
@@ -238,7 +235,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
.setInputCol("words")
.setOutputCol("features")
.setMinTF(0.3)
- cv.transform(df).select("features", "expected").collect().foreach {
+ testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}
@@ -258,7 +255,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
.setOutputCol("features")
.setBinary(true)
.fit(df)
- cv.transform(df).select("features", "expected").collect().foreach {
+ testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}
@@ -268,7 +265,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
.setInputCol("words")
.setOutputCol("features")
.setBinary(true)
- cv2.transform(df).select("features", "expected").collect().foreach {
+ testTransformer[(Int, Seq[String], Vector)](df, cv2, "features", "expected") {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
index 8dd3dd75e1be5..6734336aac39c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
@@ -21,16 +21,14 @@ import scala.beans.BeanInfo
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.Row
@BeanInfo
case class DCTTestData(vec: Vector, wantedVec: Vector)
-class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class DCTSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -72,11 +70,9 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
.setOutputCol("resultVec")
.setInverse(inverse)
- transformer.transform(dataset)
- .select("resultVec", "wantedVec")
- .collect()
- .foreach { case Row(resultVec: Vector, wantedVec: Vector) =>
- assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
+ testTransformer[(Vector, Vector)](dataset, transformer, "resultVec", "wantedVec") {
+ case Row(resultVec: Vector, wantedVec: Vector) =>
+ assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala
index a4cca27be7815..3a8d0762e2ab7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala
@@ -17,13 +17,31 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.sql.Row
-class ElementwiseProductSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class ElementwiseProductSuite extends MLTest with DefaultReadWriteTest {
+
+ import testImplicits._
+
+ test("streaming transform") {
+ val scalingVec = Vectors.dense(0.1, 10.0)
+ val data = Seq(
+ (Vectors.dense(0.1, 1.0), Vectors.dense(0.01, 10.0)),
+ (Vectors.dense(0.0, -1.1), Vectors.dense(0.0, -11.0))
+ )
+ val df = spark.createDataFrame(data).toDF("features", "expected")
+ val ep = new ElementwiseProduct()
+ .setInputCol("features")
+ .setOutputCol("actual")
+ .setScalingVec(scalingVec)
+ testTransformer[(Vector, Vector)](df, ep, "actual", "expected") {
+ case Row(actual: Vector, expected: Vector) =>
+ assert(actual ~== expected relTol 1e-14)
+ }
+ }
test("read/write") {
val ep = new ElementwiseProduct()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala
index 3fc3cbb62d5b5..d799ba6011fa8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala
@@ -17,26 +17,24 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
-class FeatureHasherSuite extends SparkFunSuite
- with MLlibTestSparkContext
- with DefaultReadWriteTest {
+class FeatureHasherSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
- import HashingTFSuite.murmur3FeatureIdx
+ import FeatureHasherSuite.murmur3FeatureIdx
- implicit private val vectorEncoder = ExpressionEncoder[Vector]()
+ implicit private val vectorEncoder: ExpressionEncoder[Vector] = ExpressionEncoder[Vector]()
test("params") {
ParamsSuite.checkParams(new FeatureHasher)
@@ -51,31 +49,31 @@ class FeatureHasherSuite extends SparkFunSuite
}
test("feature hashing") {
+ val numFeatures = 100
+ // Assume a perfect hash on field names when computing expected results
+ def idx: Any => Int = murmur3FeatureIdx(numFeatures)
+
val df = Seq(
- (2.0, true, "1", "foo"),
- (3.0, false, "2", "bar")
- ).toDF("real", "bool", "stringNum", "string")
+ (2.0, true, "1", "foo",
+ Vectors.sparse(numFeatures, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0),
+ (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0)))),
+ (3.0, false, "2", "bar",
+ Vectors.sparse(numFeatures, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0),
+ (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0))))
+ ).toDF("real", "bool", "stringNum", "string", "expected")
- val n = 100
val hasher = new FeatureHasher()
.setInputCols("real", "bool", "stringNum", "string")
.setOutputCol("features")
- .setNumFeatures(n)
+ .setNumFeatures(numFeatures)
val output = hasher.transform(df)
val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
- assert(attrGroup.numAttributes === Some(n))
+ assert(attrGroup.numAttributes === Some(numFeatures))
- val features = output.select("features").as[Vector].collect()
- // Assume perfect hash on field names
- def idx: Any => Int = murmur3FeatureIdx(n)
- // check expected indices
- val expected = Seq(
- Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0),
- (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0))),
- Vectors.sparse(n, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0),
- (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0)))
- )
- assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 })
+ testTransformer[(Double, Boolean, String, String, Vector)](df, hasher, "features", "expected") {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14 )
+ }
}
test("setting explicit numerical columns to treat as categorical") {
@@ -216,3 +214,11 @@ class FeatureHasherSuite extends SparkFunSuite
testDefaultReadWrite(t)
}
}
+
+object FeatureHasherSuite {
+
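+ /** Maps a term to its feature index using the same murmur3 hashing as FeatureHasher. */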
+ private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
+ Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures)
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index a46272fdce1fb..c5183ecfef7d7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -17,17 +17,16 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
import org.apache.spark.util.Utils
-class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class HashingTFSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
import HashingTFSuite.murmur3FeatureIdx
@@ -37,21 +36,28 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
test("hashingTF") {
- val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words")
- val n = 100
+ val numFeatures = 100
+ // Assume perfect hash when computing expected features.
+ def idx: Any => Int = murmur3FeatureIdx(numFeatures)
+ val data = Seq(
+ ("a a b b c d".split(" ").toSeq,
+ Vectors.sparse(numFeatures,
+ Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))))
+ )
+
+ val df = data.toDF("words", "expected")
val hashingTF = new HashingTF()
.setInputCol("words")
.setOutputCol("features")
- .setNumFeatures(n)
+ .setNumFeatures(numFeatures)
val output = hashingTF.transform(df)
val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
- require(attrGroup.numAttributes === Some(n))
- val features = output.select("features").first().getAs[Vector](0)
- // Assume perfect hash on "a", "b", "c", and "d".
- def idx: Any => Int = murmur3FeatureIdx(n)
- val expected = Vectors.sparse(n,
- Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
- assert(features ~== expected absTol 1e-14)
+ require(attrGroup.numAttributes === Some(numFeatures))
+
+ testTransformer[(Seq[String], Vector)](df, hashingTF, "features", "expected") {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
}
test("applying binary term freqs") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index 005edf73d29be..cdd62be43b54c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -17,17 +17,15 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
-class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class IDFSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -57,7 +55,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
Vectors.dense(0.0, 1.0, 2.0, 3.0),
Vectors.sparse(numOfFeatures, Array(1), Array(1.0))
)
- val numOfData = data.size
+ val numOfData = data.length
val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
math.log((numOfData + 1.0) / (x + 1.0))
})
@@ -72,7 +70,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
MLTestingUtils.checkCopyAndUids(idfEst, idfModel)
- idfModel.transform(df).select("idfValue", "expected").collect().foreach {
+ testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") {
case Row(x: Vector, y: Vector) =>
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}
@@ -85,7 +83,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
Vectors.dense(0.0, 1.0, 2.0, 3.0),
Vectors.sparse(numOfFeatures, Array(1), Array(1.0))
)
- val numOfData = data.size
+ val numOfData = data.length
val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0
})
@@ -99,7 +97,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
.setMinDocFreq(1)
.fit(df)
- idfModel.transform(df).select("idfValue", "expected").collect().foreach {
+ testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") {
case Row(x: Vector, y: Vector) =>
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
index c08b35b419266..75f63a623e6d8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -16,13 +16,12 @@
*/
package org.apache.spark.ml.feature
-import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.SparkException
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class ImputerSuite extends MLTest with DefaultReadWriteTest {
test("Imputer for Double with default missing Value NaN") {
val df = spark.createDataFrame( Seq(
@@ -76,6 +75,28 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default
ImputerSuite.iterateStrategyTest(imputer, df)
}
+ test("Imputer should work with Structured Streaming") {
+ val localSpark = spark
+ import localSpark.implicits._
+ val df = Seq[(java.lang.Double, Double)](
+ (4.0, 4.0),
+ (10.0, 10.0),
+ (10.0, 10.0),
+ (Double.NaN, 8.0),
+ (null, 8.0)
+ ).toDF("value", "expected_mean_value")
+ val imputer = new Imputer()
+ .setInputCols(Array("value"))
+ .setOutputCols(Array("out"))
+ .setStrategy("mean")
+ val model = imputer.fit(df)
+ testTransformer[(java.lang.Double, Double)](df, model, "expected_mean_value", "out") {
+ case Row(exp: java.lang.Double, out: Double) =>
+ assert((exp.isNaN && out.isNaN) || (exp == out),
+ s"Imputed values differ. Expected: $exp, actual: $out")
+ }
+ }
+
test("Imputer throws exception when surrogate cannot be computed") {
val df = spark.createDataFrame( Seq(
(0, Double.NaN, 1.0, 1.0),
@@ -164,8 +185,6 @@ object ImputerSuite {
* @param df DataFrame with columns "id", "value", "expected_mean", "expected_median"
*/
def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = {
- val inputCols = imputer.getInputCols
-
Seq("mean", "median").foreach { strategy =>
imputer.setStrategy(strategy)
val model = imputer.fit(df)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
index 54f059e5f143e..eea31fc7ae3f2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
@@ -19,15 +19,15 @@ package org.apache.spark.ml.feature
import scala.collection.mutable.ArrayBuilder
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
-class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class InteractionSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -63,9 +63,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
test("numeric interaction") {
val data = Seq(
- (2, Vectors.dense(3.0, 4.0)),
- (1, Vectors.dense(1.0, 5.0))
- ).toDF("a", "b")
+ (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)),
+ (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))
+ ).toDF("a", "b", "expected")
val groupAttr = new AttributeGroup(
"b",
Array[Attribute](
@@ -73,14 +73,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
NumericAttribute.defaultAttr.withName("bar")))
val df = data.select(
col("a").as("a", NumericAttribute.defaultAttr.toMetadata()),
- col("b").as("b", groupAttr.toMetadata()))
+ col("b").as("b", groupAttr.toMetadata()),
+ col("expected"))
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
+ testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features === expected)
+ }
+
val res = trans.transform(df)
- val expected = Seq(
- (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)),
- (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))
- ).toDF("a", "b", "features")
- assert(res.collect() === expected.collect())
val attrs = AttributeGroup.fromStructField(res.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
@@ -92,9 +93,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
test("nominal interaction") {
val data = Seq(
- (2, Vectors.dense(3.0, 4.0)),
- (1, Vectors.dense(1.0, 5.0))
- ).toDF("a", "b")
+ (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)),
+ (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))
+ ).toDF("a", "b", "expected")
val groupAttr = new AttributeGroup(
"b",
Array[Attribute](
@@ -103,14 +104,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def
val df = data.select(
col("a").as(
"a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()),
- col("b").as("b", groupAttr.toMetadata()))
+ col("b").as("b", groupAttr.toMetadata()),
+ col("expected"))
val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features")
+ testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features === expected)
+ }
+
val res = trans.transform(df)
- val expected = Seq(
- (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)),
- (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))
- ).toDF("a", "b", "features")
- assert(res.collect() === expected.collect())
val attrs = AttributeGroup.fromStructField(res.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
index 918da4f9388d4..8dd0f0cb91e37 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
@@ -14,15 +14,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.sql.Row
-class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class MaxAbsScalerSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -45,9 +44,10 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
.setOutputCol("scaled")
val model = scaler.fit(df)
- model.transform(df).select("expected", "scaled").collect()
- .foreach { case Row(vector1: Vector, vector2: Vector) =>
- assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1")
+ testTransformer[(Vector, Vector)](df, model, "expected", "scaled") {
+ case Row(expectedVec: Vector, actualVec: Vector) =>
+ assert(expectedVec === actualVec,
+ s"MaxAbsScaler error: Expected $expectedVec but computed $actualVec")
}
MLTestingUtils.checkCopyAndUids(scaler, model)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index 96df68dbdf053..1c2956cb82908 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Dataset
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.sql.{Dataset, Row}
-class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+class MinHashLSHSuite extends MLTest with DefaultReadWriteTest {
@transient var dataset: Dataset[_] = _
@@ -43,6 +42,14 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
ParamsSuite.checkParams(model)
}
+ test("setters") {
+ val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
+ .setInputCol("testkeys")
+ .setOutputCol("testvalues")
+ assert(model.getInputCol === "testkeys")
+ assert(model.getOutputCol === "testvalues")
+ }
+
test("MinHashLSH: default params") {
val rp = new MinHashLSH
assert(rp.getNumHashTables === 1.0)
@@ -167,4 +174,20 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
assert(precision == 1.0)
assert(recall >= 0.7)
}
+
+ test("MinHashLSHModel.transform should work with Structured Streaming") {
+ val localSpark = spark
+ import localSpark.implicits._
+
+ val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
+ model.set(model.inputCol, "keys")
+ testTransformer[Tuple1[Vector]](dataset.toDF(), model, "keys", model.getOutputCol) {
+ case Row(_: Vector, output: Seq[_]) =>
+ assert(output.length === model.randCoefficients.length)
+ // no AND-amplification yet: SPARK-18450, so each hash output is of length 1
+ output.foreach {
+ case hashOutput: Vector => assert(hashOutput.size === 1)
+ }
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 51db74eb739ca..2d965f2ca2c54 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -17,13 +17,11 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.sql.Row
-class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class MinMaxScalerSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -48,9 +46,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
.setMax(5)
val model = scaler.fit(df)
- model.transform(df).select("expected", "scaled").collect()
- .foreach { case Row(vector1: Vector, vector2: Vector) =>
- assert(vector1.equals(vector2), "Transformed vector is different with expected.")
+ testTransformer[(Vector, Vector)](df, model, "expected", "scaled") {
+ case Row(vector1: Vector, vector2: Vector) =>
+ assert(vector1 === vector2, "Transformed vector is different with expected.")
}
MLTestingUtils.checkCopyAndUids(scaler, model)
@@ -114,7 +112,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
val model = scaler.fit(df)
model.transform(df).select("expected", "scaled").collect()
.foreach { case Row(vector1: Vector, vector2: Vector) =>
- assert(vector1.equals(vector2), "Transformed vector is different with expected.")
+ assert(vector1 === vector2, "Transformed vector is different with expected.")
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
index d4975c0b4e20e..201a335e0d7be 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -19,17 +19,15 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.{DataFrame, Row}
+
@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
-class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class NGramSuite extends MLTest with DefaultReadWriteTest {
- import org.apache.spark.ml.feature.NGramSuite._
import testImplicits._
test("default behavior yields bigram features") {
@@ -83,16 +81,11 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
.setN(3)
testDefaultReadWrite(t)
}
-}
-
-object NGramSuite extends SparkFunSuite {
- def testNGram(t: NGram, dataset: Dataset[_]): Unit = {
- t.transform(dataset)
- .select("nGrams", "wantedNGrams")
- .collect()
- .foreach { case Row(actualNGrams, wantedNGrams) =>
+ def testNGram(t: NGram, dataFrame: DataFrame): Unit = {
+ testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") {
+      case Row(actualNGrams: Seq[_], wantedNGrams: Seq[_]) =>
assert(actualNGrams === wantedNGrams)
- }
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index c75027fb4553d..eff57f1223af4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -17,21 +17,17 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
-class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class NormalizerSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@transient var data: Array[Vector] = _
- @transient var dataFrame: DataFrame = _
- @transient var normalizer: Normalizer = _
@transient var l1Normalized: Array[Vector] = _
@transient var l2Normalized: Array[Vector] = _
@@ -62,49 +58,40 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
Vectors.dense(0.897906166, 0.113419726, 0.42532397),
Vectors.sparse(3, Seq())
)
-
- dataFrame = data.map(NormalizerSuite.FeatureData).toSeq.toDF()
- normalizer = new Normalizer()
- .setInputCol("features")
- .setOutputCol("normalized_features")
- }
-
- def collectResult(result: DataFrame): Array[Vector] = {
- result.select("normalized_features").collect().map {
- case Row(features: Vector) => features
- }
}
- def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
- assert((lhs, rhs).zipped.forall {
+ def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = {
+ assert((lhs, rhs) match {
case (v1: DenseVector, v2: DenseVector) => true
case (v1: SparseVector, v2: SparseVector) => true
case _ => false
}, "The vector type should be preserved after normalization.")
}
- def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
- assert((lhs, rhs).zipped.forall { (vector1, vector2) =>
- vector1 ~== vector2 absTol 1E-5
- }, "The vector value is not correct after normalization.")
+ def assertValues(lhs: Vector, rhs: Vector): Unit = {
+ assert(lhs ~== rhs absTol 1E-5, "The vector value is not correct after normalization.")
}
test("Normalization with default parameter") {
- val result = collectResult(normalizer.transform(dataFrame))
-
- assertTypeOfVector(data, result)
+ val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized")
+ val dataFrame: DataFrame = data.zip(l2Normalized).seq.toDF("features", "expected")
- assertValues(result, l2Normalized)
+ testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") {
+ case Row(features: Vector, normalized: Vector, expected: Vector) =>
+ assertTypeOfVector(normalized, features)
+ assertValues(normalized, expected)
+ }
}
test("Normalization with setter") {
- normalizer.setP(1)
+ val dataFrame: DataFrame = data.zip(l1Normalized).seq.toDF("features", "expected")
+ val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized").setP(1)
- val result = collectResult(normalizer.transform(dataFrame))
-
- assertTypeOfVector(data, result)
-
- assertValues(result, l1Normalized)
+ testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") {
+ case Row(features: Vector, normalized: Vector, expected: Vector) =>
+ assertTypeOfVector(normalized, features)
+ assertValues(normalized, expected)
+ }
}
test("read/write") {
@@ -115,7 +102,3 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
testDefaultReadWrite(t)
}
}
-
-private object NormalizerSuite {
- case class FeatureData(features: Vector)
-}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
index 1d3f845586426..d549e13262273 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala
@@ -17,18 +17,16 @@
package org.apache.spark.ml.feature
-import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.{Encoder, Row}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
-class OneHotEncoderEstimatorSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class OneHotEncoderEstimatorSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -57,13 +55,10 @@ class OneHotEncoderEstimatorSuite
assert(encoder.getDropLast === true)
encoder.setDropLast(false)
assert(encoder.getDropLast === false)
-
val model = encoder.fit(df)
- val encoded = model.transform(df)
- encoded.select("output", "expected").rdd.map { r =>
- (r.getAs[Vector](0), r.getAs[Vector](1))
- }.collect().foreach { case (vec1, vec2) =>
- assert(vec1 === vec2)
+ testTransformer[(Double, Vector)](df, model, "output", "expected") {
+ case Row(output: Vector, expected: Vector) =>
+ assert(output === expected)
}
}
@@ -87,11 +82,9 @@ class OneHotEncoderEstimatorSuite
.setOutputCols(Array("output"))
val model = encoder.fit(df)
- val encoded = model.transform(df)
- encoded.select("output", "expected").rdd.map { r =>
- (r.getAs[Vector](0), r.getAs[Vector](1))
- }.collect().foreach { case (vec1, vec2) =>
- assert(vec1 === vec2)
+ testTransformer[(Double, Vector)](df, model, "output", "expected") {
+ case Row(output: Vector, expected: Vector) =>
+ assert(output === expected)
}
}
@@ -103,11 +96,12 @@ class OneHotEncoderEstimatorSuite
.setInputCols(Array("size"))
.setOutputCols(Array("encoded"))
val model = encoder.fit(df)
- val output = model.transform(df)
- val group = AttributeGroup.fromStructField(output.schema("encoded"))
- assert(group.size === 2)
- assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
- assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
+ testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows =>
+ val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
+ assert(group.size === 2)
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
+ }
}
test("input column without ML attribute") {
@@ -116,11 +110,12 @@ class OneHotEncoderEstimatorSuite
.setInputCols(Array("index"))
.setOutputCols(Array("encoded"))
val model = encoder.fit(df)
- val output = model.transform(df)
- val group = AttributeGroup.fromStructField(output.schema("encoded"))
- assert(group.size === 2)
- assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
- assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
+ testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows =>
+ val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
+ assert(group.size === 2)
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
+ }
}
test("read/write") {
@@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite
val df = spark.createDataFrame(sc.parallelize(data), schema)
- val dfWithTypes = df
- .withColumn("shortInput", df("input").cast(ShortType))
- .withColumn("longInput", df("input").cast(LongType))
- .withColumn("intInput", df("input").cast(IntegerType))
- .withColumn("floatInput", df("input").cast(FloatType))
- .withColumn("decimalInput", df("input").cast(DecimalType(10, 0)))
-
- val cols = Array("input", "shortInput", "longInput", "intInput",
- "floatInput", "decimalInput")
- for (col <- cols) {
- val encoder = new OneHotEncoderEstimator()
- .setInputCols(Array(col))
+ class NumericTypeWithEncoder[A](val numericType: NumericType)
+ (implicit val encoder: Encoder[(A, Vector)])
+
+ val types = Seq(
+ new NumericTypeWithEncoder[Short](ShortType),
+ new NumericTypeWithEncoder[Long](LongType),
+ new NumericTypeWithEncoder[Int](IntegerType),
+ new NumericTypeWithEncoder[Float](FloatType),
+ new NumericTypeWithEncoder[Byte](ByteType),
+ new NumericTypeWithEncoder[Double](DoubleType),
+ new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))
+
+ for (t <- types) {
+ val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected"))
+ val estimator = new OneHotEncoderEstimator()
+ .setInputCols(Array("input"))
.setOutputCols(Array("output"))
.setDropLast(false)
- val model = encoder.fit(dfWithTypes)
- val encoded = model.transform(dfWithTypes)
-
- encoded.select("output", "expected").rdd.map { r =>
- (r.getAs[Vector](0), r.getAs[Vector](1))
- }.collect().foreach { case (vec1, vec2) =>
- assert(vec1 === vec2)
- }
+ val model = estimator.fit(dfWithTypes)
+ testTransformer(dfWithTypes, model, "output", "expected") {
+ case Row(output: Vector, expected: Vector) =>
+ assert(output === expected)
+ }(t.encoder)
}
}
@@ -202,12 +198,16 @@ class OneHotEncoderEstimatorSuite
assert(encoder.getDropLast === false)
val model = encoder.fit(df)
- val encoded = model.transform(df)
- encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r =>
- (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3))
- }.collect().foreach { case (vec1, vec2, vec3, vec4) =>
- assert(vec1 === vec2)
- assert(vec3 === vec4)
+ testTransformer[(Double, Vector, Double, Vector)](
+ df,
+ model,
+ "output1",
+ "output2",
+ "expected1",
+ "expected2") {
+ case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) =>
+ assert(output1 === expected1)
+ assert(output2 === expected2)
}
}
@@ -233,12 +233,16 @@ class OneHotEncoderEstimatorSuite
.setOutputCols(Array("output1", "output2"))
val model = encoder.fit(df)
- val encoded = model.transform(df)
- encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r =>
- (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3))
- }.collect().foreach { case (vec1, vec2, vec3, vec4) =>
- assert(vec1 === vec2)
- assert(vec3 === vec4)
+ testTransformer[(Double, Vector, Double, Vector)](
+ df,
+ model,
+ "output1",
+ "output2",
+ "expected1",
+ "expected2") {
+ case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) =>
+ assert(output1 === expected1)
+ assert(output2 === expected2)
}
}
@@ -253,10 +257,12 @@ class OneHotEncoderEstimatorSuite
.setOutputCols(Array("encoded"))
val model = encoder.fit(trainingDF)
- val err = intercept[SparkException] {
- model.transform(testDF).show
- }
- err.getMessage.contains("Unseen value: 3.0. To handle unseen values")
+ testTransformerByInterceptingException[(Int, Int)](
+ testDF,
+ model,
+ expectedMessagePart = "Unseen value: 3.0. To handle unseen values",
+ firstResultCol = "encoded")
+
}
test("Can't transform on negative input") {
@@ -268,10 +274,11 @@ class OneHotEncoderEstimatorSuite
.setOutputCols(Array("encoded"))
val model = encoder.fit(trainingDF)
- val err = intercept[SparkException] {
- model.transform(testDF).collect()
- }
- err.getMessage.contains("Negative value: -1.0. Input can't be negative")
+ testTransformerByInterceptingException[(Int, Int)](
+ testDF,
+ model,
+ expectedMessagePart = "Negative value: -1.0. Input can't be negative",
+ firstResultCol = "encoded")
}
test("Keep on invalid values: dropLast = false") {
@@ -295,11 +302,9 @@ class OneHotEncoderEstimatorSuite
.setDropLast(false)
val model = encoder.fit(trainingDF)
- val encoded = model.transform(testDF)
- encoded.select("output", "expected").rdd.map { r =>
- (r.getAs[Vector](0), r.getAs[Vector](1))
- }.collect().foreach { case (vec1, vec2) =>
- assert(vec1 === vec2)
+ testTransformer[(Double, Vector)](testDF, model, "output", "expected") {
+ case Row(output: Vector, expected: Vector) =>
+ assert(output === expected)
}
}
@@ -324,11 +329,9 @@ class OneHotEncoderEstimatorSuite
.setDropLast(true)
val model = encoder.fit(trainingDF)
- val encoded = model.transform(testDF)
- encoded.select("output", "expected").rdd.map { r =>
- (r.getAs[Vector](0), r.getAs[Vector](1))
- }.collect().foreach { case (vec1, vec2) =>
- assert(vec1 === vec2)
+ testTransformer[(Double, Vector)](testDF, model, "output", "expected") {
+ case Row(output: Vector, expected: Vector) =>
+ assert(output === expected)
}
}
@@ -355,19 +358,15 @@ class OneHotEncoderEstimatorSuite
val model = encoder.fit(df)
model.setDropLast(false)
- val encoded1 = model.transform(df)
- encoded1.select("output", "expected1").rdd.map { r =>
- (r.getAs[Vector](0), r.getAs[Vector](1))
- }.collect().foreach { case (vec1, vec2) =>
- assert(vec1 === vec2)
+ testTransformer[(Double, Vector, Vector)](df, model, "output", "expected1") {
+ case Row(output: Vector, expected1: Vector) =>
+ assert(output === expected1)
}
model.setDropLast(true)
- val encoded2 = model.transform(df)
- encoded2.select("output", "expected2").rdd.map { r =>
- (r.getAs[Vector](0), r.getAs[Vector](1))
- }.collect().foreach { case (vec1, vec2) =>
- assert(vec1 === vec2)
+ testTransformer[(Double, Vector, Vector)](df, model, "output", "expected2") {
+ case Row(output: Vector, expected2: Vector) =>
+ assert(output === expected2)
}
}
@@ -392,13 +391,14 @@ class OneHotEncoderEstimatorSuite
val model = encoder.fit(trainingDF)
model.setHandleInvalid("error")
- val err = intercept[SparkException] {
- model.transform(testDF).collect()
- }
- err.getMessage.contains("Unseen value: 3.0. To handle unseen values")
+ testTransformerByInterceptingException[(Double, Vector)](
+ testDF,
+ model,
+ expectedMessagePart = "Unseen value: 3.0. To handle unseen values",
+ firstResultCol = "output")
model.setHandleInvalid("keep")
- model.transform(testDF).collect()
+ testTransformerByGlobalCheckFunc[(Double, Vector)](testDF, model, "output") { _ => }
}
test("Transforming on mismatched attributes") {
@@ -413,9 +413,10 @@ class OneHotEncoderEstimatorSuite
val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large")
val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size")
.select(col("size").as("size", testAttr.toMetadata()))
- val err = intercept[Exception] {
- model.transform(testDF).collect()
- }
- err.getMessage.contains("OneHotEncoderModel expected 2 categorical values")
+ testTransformerByInterceptingException[(Double)](
+ testDF,
+ model,
+ expectedMessagePart = "OneHotEncoderModel expected 2 categorical values",
+ firstResultCol = "encoded")
}
}
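The NumericTypeWithEncoder wrapper introduced above exists because testTransformer needs an Encoder for the row type, and that encoder changes with the numeric type being swept. A minimal, self-contained sketch of the same pattern follows; the object name, the local SparkSession, and the printed message are illustrative only and not part of this patch.

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.types.{IntegerType, LongType, NumericType, ShortType}

object NumericTypeSweepSketch {
  // Pairs a SQL numeric type with the implicit tuple Encoder that a
  // type-parameterised test helper would need for that type.
  class NumericTypeWithEncoder[A](val numericType: NumericType)
      (implicit val encoder: Encoder[(A, Vector)])

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    // One entry per numeric type to sweep; each entry captures both the cast
    // target and the encoder to pass on explicitly, e.g. testTransformer(...)(t.encoder).
    val types = Seq(
      new NumericTypeWithEncoder[Short](ShortType),
      new NumericTypeWithEncoder[Int](IntegerType),
      new NumericTypeWithEncoder[Long](LongType))

    types.foreach(t => println(s"sweeping input cast to ${t.numericType.simpleString}"))
    spark.stop()
  }
}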
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index c44c6813a94be..41b32b2ffa096 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -17,18 +17,18 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.{DataFrame, Encoder, Row}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
class OneHotEncoderSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+ extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -54,16 +54,19 @@ class OneHotEncoderSuite
assert(encoder.getDropLast === true)
encoder.setDropLast(false)
assert(encoder.getDropLast === false)
- val encoded = encoder.transform(transformed)
-
- val output = encoded.select("id", "labelVec").rdd.map { r =>
- val vec = r.getAs[Vector](1)
- (r.getInt(0), vec(0), vec(1), vec(2))
- }.collect().toSet
- // a -> 0, b -> 2, c -> 1
- val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
- (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
- assert(output === expected)
+ val expected = Seq(
+ (0, Vectors.sparse(3, Seq((0, 1.0)))),
+ (1, Vectors.sparse(3, Seq((2, 1.0)))),
+ (2, Vectors.sparse(3, Seq((1, 1.0)))),
+ (3, Vectors.sparse(3, Seq((0, 1.0)))),
+ (4, Vectors.sparse(3, Seq((0, 1.0)))),
+ (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected")
+
+ val withExpected = transformed.join(expected, "id")
+ testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") {
+ case Row(output: Vector, expected: Vector) =>
+ assert(output === expected)
+ }
}
test("OneHotEncoder dropLast = true") {
@@ -71,16 +74,19 @@ class OneHotEncoderSuite
val encoder = new OneHotEncoder()
.setInputCol("labelIndex")
.setOutputCol("labelVec")
- val encoded = encoder.transform(transformed)
-
- val output = encoded.select("id", "labelVec").rdd.map { r =>
- val vec = r.getAs[Vector](1)
- (r.getInt(0), vec(0), vec(1))
- }.collect().toSet
- // a -> 0, b -> 2, c -> 1
- val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0),
- (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0))
- assert(output === expected)
+ val expected = Seq(
+ (0, Vectors.sparse(2, Seq((0, 1.0)))),
+ (1, Vectors.sparse(2, Seq())),
+ (2, Vectors.sparse(2, Seq((1, 1.0)))),
+ (3, Vectors.sparse(2, Seq((0, 1.0)))),
+ (4, Vectors.sparse(2, Seq((0, 1.0)))),
+ (5, Vectors.sparse(2, Seq((1, 1.0))))).toDF("id", "expected")
+
+ val withExpected = transformed.join(expected, "id")
+ testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") {
+ case Row(output: Vector, expected: Vector) =>
+ assert(output === expected)
+ }
}
test("input column with ML attribute") {
@@ -90,20 +96,22 @@ class OneHotEncoderSuite
val encoder = new OneHotEncoder()
.setInputCol("size")
.setOutputCol("encoded")
- val output = encoder.transform(df)
- val group = AttributeGroup.fromStructField(output.schema("encoded"))
- assert(group.size === 2)
- assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
- assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
+ testTransformerByGlobalCheckFunc[(Double)](df, encoder, "encoded") { rows =>
+ val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
+ assert(group.size === 2)
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
+ }
}
+
test("input column without ML attribute") {
val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index")
val encoder = new OneHotEncoder()
.setInputCol("index")
.setOutputCol("encoded")
- val output = encoder.transform(df)
- val group = AttributeGroup.fromStructField(output.schema("encoded"))
+ val rows = encoder.transform(df).select("encoded").collect()
+ val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
assert(group.size === 2)
assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
@@ -119,29 +127,41 @@ class OneHotEncoderSuite
test("OneHotEncoder with varying types") {
val df = stringIndexed()
- val dfWithTypes = df
- .withColumn("shortLabel", df("labelIndex").cast(ShortType))
- .withColumn("longLabel", df("labelIndex").cast(LongType))
- .withColumn("intLabel", df("labelIndex").cast(IntegerType))
- .withColumn("floatLabel", df("labelIndex").cast(FloatType))
- .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0)))
- val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel",
- "floatLabel", "decimalLabel")
- for (col <- cols) {
+ val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
+ val expected = Seq(
+ (0, Vectors.sparse(3, Seq((0, 1.0)))),
+ (1, Vectors.sparse(3, Seq((2, 1.0)))),
+ (2, Vectors.sparse(3, Seq((1, 1.0)))),
+ (3, Vectors.sparse(3, Seq((0, 1.0)))),
+ (4, Vectors.sparse(3, Seq((0, 1.0)))),
+ (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected")
+
+ val withExpected = df.join(expected, "id")
+
+ class NumericTypeWithEncoder[A](val numericType: NumericType)
+ (implicit val encoder: Encoder[(A, Vector)])
+
+ val types = Seq(
+ new NumericTypeWithEncoder[Short](ShortType),
+ new NumericTypeWithEncoder[Long](LongType),
+ new NumericTypeWithEncoder[Int](IntegerType),
+ new NumericTypeWithEncoder[Float](FloatType),
+ new NumericTypeWithEncoder[Byte](ByteType),
+ new NumericTypeWithEncoder[Double](DoubleType),
+ new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))
+
+ for (t <- types) {
+ val dfWithTypes = withExpected.select(col("labelIndex")
+ .cast(t.numericType).as("labelIndex", attr.toMetadata()), col("expected"))
val encoder = new OneHotEncoder()
- .setInputCol(col)
+ .setInputCol("labelIndex")
.setOutputCol("labelVec")
.setDropLast(false)
- val encoded = encoder.transform(dfWithTypes)
-
- val output = encoded.select("id", "labelVec").rdd.map { r =>
- val vec = r.getAs[Vector](1)
- (r.getInt(0), vec(0), vec(1), vec(2))
- }.collect().toSet
- // a -> 0, b -> 2, c -> 1
- val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
- (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
- assert(output === expected)
+
+ testTransformer(dfWithTypes, encoder, "labelVec", "expected") {
+ case Row(output: Vector, expected: Vector) =>
+ assert(output === expected)
+ }(t.encoder)
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index 3067a52a4df76..531b1d7c4d9f7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -17,17 +17,15 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
-class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class PCASuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -62,10 +60,10 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val pcaModel = pca.fit(df)
MLTestingUtils.checkCopyAndUids(pca, pcaModel)
-
- pcaModel.transform(df).select("pca_features", "expected").collect().foreach {
- case Row(x: Vector, y: Vector) =>
- assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
+ testTransformer[(Vector, Vector)](df, pcaModel, "pca_features", "expected") {
+ case Row(result: Vector, expected: Vector) =>
+ assert(result ~== expected absTol 1e-5,
+ "Transformed vector is different with expected vector.")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index e4b0ddf98bfad..0be7aa6c83f29 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -17,18 +17,13 @@
package org.apache.spark.ml.feature
-import org.scalatest.exceptions.TestFailedException
-
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
-class PolynomialExpansionSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class PolynomialExpansionSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -60,6 +55,18 @@ class PolynomialExpansionSuite
-1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0),
Vectors.sparse(19, Array.empty, Array.empty))
+ def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = {
+ assert((lhs, rhs) match {
+ case (v1: DenseVector, v2: DenseVector) => true
+ case (v1: SparseVector, v2: SparseVector) => true
+ case _ => false
+ }, "The vector type should be preserved after polynomial expansion.")
+ }
+
+ def assertValues(lhs: Vector, rhs: Vector): Unit = {
+ assert(lhs ~== rhs absTol 1e-1, "The vector value is not correct after polynomial expansion.")
+ }
+
test("Polynomial expansion with default parameter") {
val df = data.zip(twoDegreeExpansion).toSeq.toDF("features", "expected")
@@ -67,13 +74,10 @@ class PolynomialExpansionSuite
.setInputCol("features")
.setOutputCol("polyFeatures")
- polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach {
- case Row(expanded: DenseVector, expected: DenseVector) =>
- assert(expanded ~== expected absTol 1e-1)
- case Row(expanded: SparseVector, expected: SparseVector) =>
- assert(expanded ~== expected absTol 1e-1)
- case _ =>
- throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
+ testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") {
+ case Row(expanded: Vector, expected: Vector) =>
+ assertTypeOfVector(expanded, expected)
+ assertValues(expanded, expected)
}
}
@@ -85,13 +89,10 @@ class PolynomialExpansionSuite
.setOutputCol("polyFeatures")
.setDegree(3)
- polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach {
- case Row(expanded: DenseVector, expected: DenseVector) =>
- assert(expanded ~== expected absTol 1e-1)
- case Row(expanded: SparseVector, expected: SparseVector) =>
- assert(expanded ~== expected absTol 1e-1)
- case _ =>
- throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
+ testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") {
+ case Row(expanded: Vector, expected: Vector) =>
+ assertTypeOfVector(expanded, expected)
+ assertValues(expanded, expected)
}
}
@@ -103,11 +104,9 @@ class PolynomialExpansionSuite
.setOutputCol("polyFeatures")
.setDegree(1)
- polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach {
+ testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") {
case Row(expanded: Vector, expected: Vector) =>
- assert(expanded ~== expected absTol 1e-1)
- case _ =>
- throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
+ assertValues(expanded, expected)
}
}
@@ -133,12 +132,13 @@ class PolynomialExpansionSuite
.setOutputCol("polyFeatures")
for (i <- Seq(10, 11)) {
- val transformed = t.setDegree(i)
- .transform(df)
- .select(s"expectedPoly${i}size", "polyFeatures")
- .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size }
-
- assert(transformed.collect.forall(identity))
+ testTransformer[(Vector, Int, Int)](
+ df,
+ t.setDegree(i),
+ s"expectedPoly${i}size",
+ "polyFeatures") { case Row(size: Int, expected: Vector) =>
+ assert(size === expected.size)
+ }
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index e9a75e931e6a8..b009038bbd833 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -17,15 +17,13 @@
package org.apache.spark.ml.feature
-import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.Pipeline
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql._
-import org.apache.spark.sql.functions.udf
-class QuantileDiscretizerSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
+
+ import testImplicits._
test("Test observed number of buckets and their sizes match expected values") {
val spark = this.spark
@@ -38,19 +36,19 @@ class QuantileDiscretizerSuite
.setInputCol("input")
.setOutputCol("result")
.setNumBuckets(numBuckets)
- val result = discretizer.fit(df).transform(df)
-
- val observedNumBuckets = result.select("result").distinct.count
- assert(observedNumBuckets === numBuckets,
- "Observed number of buckets does not equal expected number of buckets.")
+ val model = discretizer.fit(df)
- val relativeError = discretizer.getRelativeError
- val isGoodBucket = udf {
- (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
+ testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows =>
+ val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result")
+ val observedNumBuckets = result.select("result").distinct.count
+ assert(observedNumBuckets === numBuckets,
+ "Observed number of buckets does not equal expected number of buckets.")
+ val relativeError = discretizer.getRelativeError
+ val numGoodBuckets = result.groupBy("result").count
+ .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}").count
+ assert(numGoodBuckets === numBuckets,
+ "Bucket sizes are not within expected relative error tolerance.")
}
- val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count
- assert(numGoodBuckets === numBuckets,
- "Bucket sizes are not within expected relative error tolerance.")
}
test("Test on data with high proportion of duplicated values") {
@@ -65,11 +63,14 @@ class QuantileDiscretizerSuite
.setInputCol("input")
.setOutputCol("result")
.setNumBuckets(numBuckets)
- val result = discretizer.fit(df).transform(df)
- val observedNumBuckets = result.select("result").distinct.count
- assert(observedNumBuckets == expectedNumBuckets,
- s"Observed number of buckets are not correct." +
- s" Expected $expectedNumBuckets but found $observedNumBuckets")
+ val model = discretizer.fit(df)
+ testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows =>
+ val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result")
+ val observedNumBuckets = result.select("result").distinct.count
+ assert(observedNumBuckets == expectedNumBuckets,
+ s"Observed number of buckets are not correct." +
+ s" Expected $expectedNumBuckets but found $observedNumBuckets")
+ }
}
test("Test transform on data with NaN value") {
@@ -88,17 +89,20 @@ class QuantileDiscretizerSuite
withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") {
val dataFrame: DataFrame = validData.toSeq.toDF("input")
- intercept[SparkException] {
- discretizer.fit(dataFrame).transform(dataFrame).collect()
- }
+ val model = discretizer.fit(dataFrame)
+ testTransformerByInterceptingException[(Double)](
+ dataFrame,
+ model,
+ expectedMessagePart = "Bucketizer encountered NaN value.",
+ firstResultCol = "result")
}
List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{
case(u, v) =>
discretizer.setHandleInvalid(u)
val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected")
- val result = discretizer.fit(dataFrame).transform(dataFrame)
- result.select("result", "expected").collect().foreach {
+ val model = discretizer.fit(dataFrame)
+ testTransformer[(Double, Double)](dataFrame, model, "result", "expected") {
case Row(x: Double, y: Double) =>
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
@@ -117,14 +121,17 @@ class QuantileDiscretizerSuite
.setOutputCol("result")
.setNumBuckets(5)
- val result = discretizer.fit(trainDF).transform(testDF)
- val firstBucketSize = result.filter(result("result") === 0.0).count
- val lastBucketSize = result.filter(result("result") === 4.0).count
+ val model = discretizer.fit(trainDF)
+ testTransformerByGlobalCheckFunc[(Double)](testDF, model, "result") { rows =>
+ val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result")
+ val firstBucketSize = result.filter(result("result") === 0.0).count
+ val lastBucketSize = result.filter(result("result") === 4.0).count
- assert(firstBucketSize === 30L,
- s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
- assert(lastBucketSize === 31L,
- s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
+ assert(firstBucketSize === 30L,
+ s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
+ assert(lastBucketSize === 31L,
+ s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
+ }
}
test("read/write") {
@@ -132,7 +139,10 @@ class QuantileDiscretizerSuite
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setNumBuckets(6)
- testDefaultReadWrite(t)
+
+ val readDiscretizer = testDefaultReadWrite(t)
+ val data = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("myInputCol")
+ readDiscretizer.fit(data)
}
test("Verify resulting model has parent") {
@@ -162,21 +172,24 @@ class QuantileDiscretizerSuite
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("result1", "result2"))
.setNumBuckets(numBuckets)
- val result = discretizer.fit(df).transform(df)
-
- val relativeError = discretizer.getRelativeError
- val isGoodBucket = udf {
- (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
- }
-
- for (i <- 1 to 2) {
- val observedNumBuckets = result.select("result" + i).distinct.count
- assert(observedNumBuckets === numBuckets,
- "Observed number of buckets does not equal expected number of buckets.")
-
- val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count
- assert(numGoodBuckets === numBuckets,
- "Bucket sizes are not within expected relative error tolerance.")
+ val model = discretizer.fit(df)
+ testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows =>
+ val result =
+ rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2")
+ val relativeError = discretizer.getRelativeError
+ for (i <- 1 to 2) {
+ val observedNumBuckets = result.select("result" + i).distinct.count
+ assert(observedNumBuckets === numBuckets,
+ "Observed number of buckets does not equal expected number of buckets.")
+
+ val numGoodBuckets = result
+ .groupBy("result" + i)
+ .count
+ .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}")
+ .count
+ assert(numGoodBuckets === numBuckets,
+ "Bucket sizes are not within expected relative error tolerance.")
+ }
}
}
@@ -193,12 +206,16 @@ class QuantileDiscretizerSuite
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("result1", "result2"))
.setNumBuckets(numBuckets)
- val result = discretizer.fit(df).transform(df)
- for (i <- 1 to 2) {
- val observedNumBuckets = result.select("result" + i).distinct.count
- assert(observedNumBuckets == expectedNumBucket,
- s"Observed number of buckets are not correct." +
- s" Expected $expectedNumBucket but found ($observedNumBuckets")
+ val model = discretizer.fit(df)
+ testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows =>
+ val result =
+ rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2")
+ for (i <- 1 to 2) {
+ val observedNumBuckets = result.select("result" + i).distinct.count
+ assert(observedNumBuckets == expectedNumBucket,
+ s"Observed number of buckets are not correct." +
+            s" Expected $expectedNumBucket but found $observedNumBuckets")
+ }
}
}
@@ -221,9 +238,12 @@ class QuantileDiscretizerSuite
withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") {
val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2")
- intercept[SparkException] {
- discretizer.fit(dataFrame).transform(dataFrame).collect()
- }
+ val model = discretizer.fit(dataFrame)
+ testTransformerByInterceptingException[(Double, Double)](
+ dataFrame,
+ model,
+ expectedMessagePart = "Bucketizer encountered NaN value.",
+ firstResultCol = "result1")
}
List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1, expectedSkip2)).foreach {
@@ -232,8 +252,14 @@ class QuantileDiscretizerSuite
val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map {
case (((a, b), c), d) => (a, b, c, d)
}.toSeq.toDF("input1", "input2", "expected1", "expected2")
- val result = discretizer.fit(dataFrame).transform(dataFrame)
- result.select("result1", "expected1", "result2", "expected2").collect().foreach {
+ val model = discretizer.fit(dataFrame)
+ testTransformer[(Double, Double, Double, Double)](
+ dataFrame,
+ model,
+ "result1",
+ "expected1",
+ "result2",
+ "expected2") {
case Row(x: Double, y: Double, z: Double, w: Double) =>
assert(x === y && w === z)
}
@@ -265,9 +291,16 @@ class QuantileDiscretizerSuite
.setOutputCols(Array("result1", "result2", "result3"))
.setNumBucketsArray(numBucketsArray)
- discretizer.fit(df).transform(df).
- select("result1", "expected1", "result2", "expected2", "result3", "expected3")
- .collect().foreach {
+ val model = discretizer.fit(df)
+ testTransformer[(Double, Double, Double, Double, Double, Double)](
+ df,
+ model,
+ "result1",
+ "expected1",
+ "result2",
+ "expected2",
+ "result3",
+ "expected3") {
case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) =>
assert(r1 === e1,
s"The result value is not correct after bucketing. Expected $e1 but found $r1")
@@ -319,20 +352,16 @@ class QuantileDiscretizerSuite
.setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3))
.fit(df)
- val resultForMultiCols = plForMultiCols.transform(df)
- .select("result1", "result2", "result3")
- .collect()
-
- val resultForSingleCol = plForSingleCol.transform(df)
- .select("result1", "result2", "result3")
- .collect()
+ val expected = plForSingleCol.transform(df).select("result1", "result2", "result3").collect()
- resultForSingleCol.zip(resultForMultiCols).foreach {
- case (rowForSingle, rowForMultiCols) =>
- assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) &&
- rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) &&
- rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2))
- }
+ testTransformerByGlobalCheckFunc[(Double, Double, Double)](
+ df,
+ plForMultiCols,
+ "result1",
+ "result2",
+ "result3") { rows =>
+ assert(rows === expected)
+ }
}
test("Multiple Columns: Comparing setting numBuckets with setting numBucketsArray " +
@@ -359,18 +388,16 @@ class QuantileDiscretizerSuite
.setOutputCols(Array("result1", "result2", "result3"))
.setNumBucketsArray(Array(10, 10, 10))
- val result1 = discretizerSingleNumBuckets.fit(df).transform(df)
- .select("result1", "result2", "result3")
- .collect()
- val result2 = discretizerNumBucketsArray.fit(df).transform(df)
- .select("result1", "result2", "result3")
- .collect()
-
- result1.zip(result2).foreach {
- case (row1, row2) =>
- assert(row1.getDouble(0) == row2.getDouble(0) &&
- row1.getDouble(1) == row2.getDouble(1) &&
- row1.getDouble(2) == row2.getDouble(2))
+ val model = discretizerSingleNumBuckets.fit(df)
+ val expected = model.transform(df).select("result1", "result2", "result3").collect()
+
+ testTransformerByGlobalCheckFunc[(Double, Double, Double)](
+ df,
+ discretizerNumBucketsArray.fit(df),
+ "result1",
+ "result2",
+ "result3") { rows =>
+ assert(rows === expected)
}
}
@@ -379,7 +406,12 @@ class QuantileDiscretizerSuite
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("result1", "result2"))
.setNumBucketsArray(Array(5, 10))
- testDefaultReadWrite(discretizer)
+
+ val readDiscretizer = testDefaultReadWrite(discretizer)
+ val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2")
+ readDiscretizer.fit(data)
+ assert(discretizer.hasDefault(discretizer.outputCol))
+ assert(readDiscretizer.hasDefault(readDiscretizer.outputCol))
}
test("Multiple Columns: Both inputCol and inputCols are set") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index bfe38d32dd77d..a250331efeb1d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -32,10 +32,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
def testRFormulaTransform[A: Encoder](
dataframe: DataFrame,
formulaModel: RFormulaModel,
- expected: DataFrame): Unit = {
+ expected: DataFrame,
+ expectedAttributes: AttributeGroup*): Unit = {
+ val resultSchema = formulaModel.transformSchema(dataframe.schema)
+ assert(resultSchema.json === expected.schema.json)
+ assert(resultSchema === expected.schema)
val (first +: rest) = expected.schema.fieldNames.toSeq
val expectedRows = expected.collect()
testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows =>
+ assert(rows.head.schema.toString() == resultSchema.toString())
+ for (expectedAttributeGroup <- expectedAttributes) {
+ val attributeGroup =
+ AttributeGroup.fromStructField(rows.head.schema(expectedAttributeGroup.name))
+ assert(attributeGroup === expectedAttributeGroup)
+ }
assert(rows === expectedRows)
}
}
@@ -49,15 +59,10 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
val model = formula.fit(original)
MLTestingUtils.checkCopyAndUids(formula, model)
- val result = model.transform(original)
- val resultSchema = model.transformSchema(original.schema)
val expected = Seq(
(0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
(2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)
).toDF("id", "v1", "v2", "features", "label")
- // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
- assert(result.schema.toString == resultSchema.toString)
- assert(resultSchema == expected.schema)
testRFormulaTransform[(Int, Double, Double)](original, model, expected)
}
@@ -73,9 +78,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y")
val model = formula.fit(original)
+ val expected = Seq(
+ (0, 1.0, Vectors.dense(0.0)),
+ (2, 2.0, Vectors.dense(2.0))
+ ).toDF("x", "y", "features")
val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3)
- assert(resultSchema.toString == model.transform(original).schema.toString)
+ testRFormulaTransform[(Int, Double)](original, model, expected)
}
test("label column already exists but forceIndexLabel was set with true") {
@@ -93,9 +102,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
intercept[IllegalArgumentException] {
model.transformSchema(original.schema)
}
- intercept[IllegalArgumentException] {
- model.transform(original)
- }
+ testTransformerByInterceptingException[(Int, Boolean)](
+ original,
+ model,
+ "Label column already exists and is not of type NumericType.",
+ "x")
}
test("allow missing label column for test datasets") {
@@ -105,21 +116,22 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3)
assert(!resultSchema.exists(_.name == "label"))
- assert(resultSchema.toString == model.transform(original).schema.toString)
+ val expected = Seq(
+ (0, 1.0, Vectors.dense(0.0)),
+ (2, 2.0, Vectors.dense(2.0))
+ ).toDF("x", "_not_y", "features")
+ testRFormulaTransform[(Int, Double)](original, model, expected)
}
test("allow empty label") {
val original = Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)).toDF("id", "a", "b")
val formula = new RFormula().setFormula("~ a + b")
val model = formula.fit(original)
- val result = model.transform(original)
- val resultSchema = model.transformSchema(original.schema)
val expected = Seq(
(1, 2.0, 3.0, Vectors.dense(2.0, 3.0)),
(4, 5.0, 6.0, Vectors.dense(5.0, 6.0)),
(7, 8.0, 9.0, Vectors.dense(8.0, 9.0))
).toDF("id", "a", "b", "features")
- assert(result.schema.toString == resultSchema.toString)
testRFormulaTransform[(Int, Double, Double)](original, model, expected)
}
@@ -128,15 +140,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
.toDF("id", "a", "b")
val model = formula.fit(original)
- val result = model.transform(original)
- val resultSchema = model.transformSchema(original.schema)
val expected = Seq(
(1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
(3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
(4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
).toDF("id", "a", "b", "features", "label")
- assert(result.schema.toString == resultSchema.toString)
testRFormulaTransform[(Int, String, Int)](original, model, expected)
}
@@ -175,9 +184,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
var idx = 0
for (orderType <- StringIndexer.supportedStringOrderType) {
val model = formula.setStringIndexerOrderType(orderType).fit(original)
- val result = model.transform(original)
- val resultSchema = model.transformSchema(original.schema)
- assert(result.schema.toString == resultSchema.toString)
testRFormulaTransform[(Int, String, Int)](original, model, expected(idx))
idx += 1
}
@@ -218,9 +224,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
).toDF("id", "a", "b", "features", "label")
val model = formula.fit(original)
- val result = model.transform(original)
- val resultSchema = model.transformSchema(original.schema)
- assert(result.schema.toString == resultSchema.toString)
testRFormulaTransform[(Int, String, Int)](original, model, expected)
}
@@ -254,19 +257,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val formula1 = new RFormula().setFormula("id ~ a + b + c - 1")
.setStringIndexerOrderType(StringIndexer.alphabetDesc)
val model1 = formula1.fit(original)
- val result1 = model1.transform(original)
- val resultSchema1 = model1.transformSchema(original.schema)
- // Note the column order is different between R and Spark.
- val expected1 = Seq(
- (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0),
- (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0),
- (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0),
- (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0)
- ).toDF("id", "a", "b", "c", "features", "label")
- assert(result1.schema.toString == resultSchema1.toString)
- testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1)
-
- val attrs1 = AttributeGroup.fromStructField(result1.schema("features"))
val expectedAttrs1 = new AttributeGroup(
"features",
Array[Attribute](
@@ -275,14 +265,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
new BinaryAttribute(Some("a_bar"), Some(3)),
new BinaryAttribute(Some("b_zz"), Some(4)),
new NumericAttribute(Some("c"), Some(5))))
- assert(attrs1 === expectedAttrs1)
+ // Note the column order is different between R and Spark.
+ val expected1 = Seq(
+ (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0),
+ (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0),
+ (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0),
+ (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0)
+ ).toDF("id", "a", "b", "c", "features", "label")
+
+ testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1, expectedAttrs1)
// There is no impact for string terms interaction.
val formula2 = new RFormula().setFormula("id ~ a:b + c - 1")
.setStringIndexerOrderType(StringIndexer.alphabetDesc)
val model2 = formula2.fit(original)
- val result2 = model2.transform(original)
- val resultSchema2 = model2.transformSchema(original.schema)
// Note the column order is different between R and Spark.
val expected2 = Seq(
(1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0),
@@ -290,10 +286,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0),
(4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0)
).toDF("id", "a", "b", "c", "features", "label")
- assert(result2.schema.toString == resultSchema2.toString)
- testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2)
-
- val attrs2 = AttributeGroup.fromStructField(result2.schema("features"))
val expectedAttrs2 = new AttributeGroup(
"features",
Array[Attribute](
@@ -304,7 +296,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
new NumericAttribute(Some("a_bar:b_zz"), Some(5)),
new NumericAttribute(Some("a_bar:b_zq"), Some(6)),
new NumericAttribute(Some("c"), Some(7))))
- assert(attrs2 === expectedAttrs2)
+
+ testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2, expectedAttrs2)
}
test("index string label") {
@@ -313,13 +306,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
.toDF("id", "a", "b")
val model = formula.fit(original)
+ val attr = NominalAttribute.defaultAttr
val expected = Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0),
("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)
).toDF("id", "a", "b", "features", "label")
- // assert(result.schema.toString == resultSchema.toString)
+ .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
testRFormulaTransform[(String, String, Int)](original, model, expected)
}
@@ -329,13 +323,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
- val expected = spark.createDataFrame(
- Seq(
+ val attr = NominalAttribute.defaultAttr
+ val expected = Seq(
(1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0),
(1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
(0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0),
(1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0))
- ).toDF("id", "a", "b", "features", "label")
+ .toDF("id", "a", "b", "features", "label")
+ .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
testRFormulaTransform[(Double, String, Int)](original, model, expected)
}
@@ -344,15 +339,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
.toDF("id", "a", "b")
val model = formula.fit(original)
- val result = model.transform(original)
- val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expected = Seq(
+ (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+ (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
+ (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
+ (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0))
+ .toDF("id", "a", "b", "features", "label")
val expectedAttrs = new AttributeGroup(
"features",
Array(
new BinaryAttribute(Some("a_bar"), Some(1)),
new BinaryAttribute(Some("a_foo"), Some(2)),
new NumericAttribute(Some("b"), Some(3))))
- assert(attrs === expectedAttrs)
+ testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs)
+
}
test("vector attribute generation") {
@@ -360,14 +360,19 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val original = Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
.toDF("id", "vec")
val model = formula.fit(original)
- val result = model.transform(original)
- val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val attrs = new AttributeGroup("vec", 2)
+ val expected = Seq(
+ (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0),
+ (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0))
+ .toDF("id", "vec", "features", "label")
+ .select($"id", $"vec".as("vec", attrs.toMetadata()), $"features", $"label")
val expectedAttrs = new AttributeGroup(
"features",
Array[Attribute](
new NumericAttribute(Some("vec_0"), Some(1)),
new NumericAttribute(Some("vec_1"), Some(2))))
- assert(attrs === expectedAttrs)
+
+ testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs)
}
test("vector attribute generation with unnamed input attrs") {
@@ -381,31 +386,31 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
NumericAttribute.defaultAttr)).toMetadata()
val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
val model = formula.fit(original)
- val result = model.transform(original)
- val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expected = Seq(
+ (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0),
+ (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0)
+ ).toDF("id", "vec2", "features", "label")
+ .select($"id", $"vec2".as("vec2", metadata), $"features", $"label")
val expectedAttrs = new AttributeGroup(
"features",
Array[Attribute](
new NumericAttribute(Some("vec2_0"), Some(1)),
new NumericAttribute(Some("vec2_1"), Some(2))))
- assert(attrs === expectedAttrs)
+ testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs)
}
test("numeric interaction") {
val formula = new RFormula().setFormula("a ~ b:c:d")
val original = Seq((1, 2, 4, 2), (2, 3, 4, 1)).toDF("a", "b", "c", "d")
val model = formula.fit(original)
- val result = model.transform(original)
val expected = Seq(
(1, 2, 4, 2, Vectors.dense(16.0), 1.0),
(2, 3, 4, 1, Vectors.dense(12.0), 2.0)
).toDF("a", "b", "c", "d", "features", "label")
- testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected)
- val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1))))
- assert(attrs === expectedAttrs)
+ testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected, expectedAttrs)
}
test("factor numeric interaction") {
@@ -414,7 +419,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5))
.toDF("id", "a", "b")
val model = formula.fit(original)
- val result = model.transform(original)
val expected = Seq(
(1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0),
@@ -423,15 +427,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
(4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
(4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)
).toDF("id", "a", "b", "features", "label")
- testRFormulaTransform[(Int, String, Int)](original, model, expected)
- val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
Array[Attribute](
new NumericAttribute(Some("a_baz:b"), Some(1)),
new NumericAttribute(Some("a_bar:b"), Some(2)),
new NumericAttribute(Some("a_foo:b"), Some(3))))
- assert(attrs === expectedAttrs)
+ testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs)
}
test("factor factor interaction") {
@@ -439,14 +441,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
val original =
Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
val model = formula.fit(original)
- val result = model.transform(original)
val expected = Seq(
(1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0),
(2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
(3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0)
).toDF("id", "a", "b", "features", "label")
testRFormulaTransform[(Int, String, String)](original, model, expected)
- val attrs = AttributeGroup.fromStructField(result.schema("features"))
val expectedAttrs = new AttributeGroup(
"features",
Array[Attribute](
@@ -454,7 +454,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
new NumericAttribute(Some("a_bar:b_zz"), Some(2)),
new NumericAttribute(Some("a_foo:b_zq"), Some(3)),
new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
- assert(attrs === expectedAttrs)
+ testRFormulaTransform[(Int, String, String)](original, model, expected, expectedAttrs)
}
test("read/write: RFormula") {
@@ -517,9 +517,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
// Handle unseen features.
val formula1 = new RFormula().setFormula("id ~ a + b")
- intercept[SparkException] {
- formula1.fit(df1).transform(df2).collect()
- }
+ testTransformerByInterceptingException[(Int, String, String)](
+ df2,
+ formula1.fit(df1),
+ "Unseen label:",
+ "features")
val model1 = formula1.setHandleInvalid("skip").fit(df1)
val model2 = formula1.setHandleInvalid("keep").fit(df1)
@@ -538,21 +540,28 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
// Handle unseen labels.
val formula2 = new RFormula().setFormula("b ~ a + id")
- intercept[SparkException] {
- formula2.fit(df1).transform(df2).collect()
- }
+ testTransformerByInterceptingException[(Int, String, String)](
+ df2,
+ formula2.fit(df1),
+ "Unseen label:",
+ "label")
+
val model3 = formula2.setHandleInvalid("skip").fit(df1)
val model4 = formula2.setHandleInvalid("keep").fit(df1)
+ val attr = NominalAttribute.defaultAttr
val expected3 = Seq(
(1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
(2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0)
).toDF("id", "a", "b", "features", "label")
+ .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
+
val expected4 = Seq(
(1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0),
(2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0),
(3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0)
).toDF("id", "a", "b", "features", "label")
+ .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
testRFormulaTransform[(Int, String, String)](df2, model3, expected3)
testRFormulaTransform[(Int, String, String)](df2, model4, expected4)
@@ -584,4 +593,26 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
assert(features.toArray === a +: b.toArray)
}
}
+
+ test("SPARK-23562 RFormula handleInvalid should handle invalid values in non-string columns.") {
+ val d1 = Seq(
+ (1001L, "a"),
+ (1002L, "b")).toDF("id1", "c1")
+ val d2 = Seq[(java.lang.Long, String)](
+ (20001L, "x"),
+ (20002L, "y"),
+ (null, null)).toDF("id2", "c2")
+ val dataset = d1.crossJoin(d2)
+
+ def get_output(mode: String): DataFrame = {
+ val formula = new RFormula().setFormula("c1 ~ id2").setHandleInvalid(mode)
+ formula.fit(dataset).transform(dataset).select("features", "label")
+ }
+
+ assert(intercept[SparkException](get_output("error").collect())
+ .getMessage.contains("Encountered null while assembling a row"))
+ assert(get_output("skip").count() == 4)
+ assert(get_output("keep").count() == 6)
+ }
+
}
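
Illustrative sketch (not part of the patch): the hunks above all apply the same migration — suites extend MLTest and replace direct transform(...).collect() calls with testTransformer / testTransformerByGlobalCheckFunc / testTransformerByInterceptingException, which run the check against both batch and streaming execution. The sketch below assumes the helper signatures used in this patch; Binarizer and the column names are placeholders, not code from the diff.

import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.Row

class MLTestHelpersExampleSuite extends MLTest {

  import testImplicits._

  test("row-level check with testTransformer") {
    val df = Seq((0.1, 0.0), (0.9, 1.0)).toDF("input", "expected")
    val binarizer = new Binarizer()
      .setInputCol("input")
      .setOutputCol("output")
      .setThreshold(0.5)
    // The check function runs per row, for both batch and streaming execution.
    testTransformer[(Double, Double)](df, binarizer, "output", "expected") {
      case Row(output: Double, expected: Double) => assert(output === expected)
    }
  }

  test("result-level check with testTransformerByGlobalCheckFunc") {
    val df = Seq((0.1, 0.0), (0.9, 1.0)).toDF("input", "expected")
    val binarizer = new Binarizer()
      .setInputCol("input")
      .setOutputCol("output")
      .setThreshold(0.5)
    // The check function sees all collected rows at once, which suits schema and metadata asserts.
    testTransformerByGlobalCheckFunc[(Double, Double)](df, binarizer, "output") { rows =>
      assert(rows.map(_.getDouble(0)).sorted === Seq(0.0, 1.0))
    }
  }
}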
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index 673a146e619f2..cf09418d8e0a2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -17,15 +17,12 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
-class SQLTransformerSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class SQLTransformerSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -37,14 +34,22 @@ class SQLTransformerSuite
val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
val sqlTrans = new SQLTransformer().setStatement(
"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
- val result = sqlTrans.transform(original)
- val resultSchema = sqlTrans.transformSchema(original.schema)
- val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))
+ val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))
.toDF("id", "v1", "v2", "v3", "v4")
- assert(result.schema.toString == resultSchema.toString)
- assert(resultSchema == expected.schema)
- assert(result.collect().toSeq == expected.collect().toSeq)
- assert(original.sparkSession.catalog.listTables().count() == 0)
+ val resultSchema = sqlTrans.transformSchema(original.schema)
+ testTransformerByGlobalCheckFunc[(Int, Double, Double)](
+ original,
+ sqlTrans,
+ "id",
+ "v1",
+ "v2",
+ "v3",
+ "v4") { rows =>
+ assert(rows.head.schema.toString == resultSchema.toString)
+ assert(resultSchema == expected.schema)
+ assert(rows == expected.collect().toSeq)
+ assert(original.sparkSession.catalog.listTables().count() == 0)
+ }
}
test("read/write") {
@@ -63,13 +68,13 @@ class SQLTransformerSuite
}
test("SPARK-22538: SQLTransformer should not unpersist given dataset") {
- val df = spark.range(10)
+ val df = spark.range(10).toDF()
df.cache()
df.count()
assert(df.storageLevel != StorageLevel.NONE)
- new SQLTransformer()
+ val sqlTrans = new SQLTransformer()
.setStatement("SELECT id + 1 AS id1 FROM __THIS__")
- .transform(df)
+ testTransformerByGlobalCheckFunc[Long](df, sqlTrans, "id1") { _ => }
assert(df.storageLevel != StorageLevel.NONE)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
index 350ba44baa1eb..c5c49d67194e4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -17,16 +17,13 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
-class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
- with DefaultReadWriteTest {
+class StandardScalerSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -60,12 +57,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
)
}
- def assertResult(df: DataFrame): Unit = {
- df.select("standardized_features", "expected").collect().foreach {
- case Row(vector1: Vector, vector2: Vector) =>
- assert(vector1 ~== vector2 absTol 1E-5,
- "The vector value is not correct after standardization.")
- }
+ def assertResult: Row => Unit = {
+ case Row(vector1: Vector, vector2: Vector) =>
+ assert(vector1 ~== vector2 absTol 1E-5,
+ "The vector value is not correct after standardization.")
}
test("params") {
@@ -83,7 +78,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
val standardScaler0 = standardScalerEst0.fit(df0)
MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0)
- assertResult(standardScaler0.transform(df0))
+ testTransformer[(Vector, Vector)](df0, standardScaler0, "standardized_features", "expected")(
+ assertResult)
}
test("Standardization with setter") {
@@ -112,9 +108,12 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
.setWithStd(false)
.fit(df3)
- assertResult(standardScaler1.transform(df1))
- assertResult(standardScaler2.transform(df2))
- assertResult(standardScaler3.transform(df3))
+ testTransformer[(Vector, Vector)](df1, standardScaler1, "standardized_features", "expected")(
+ assertResult)
+ testTransformer[(Vector, Vector)](df2, standardScaler2, "standardized_features", "expected")(
+ assertResult)
+ testTransformer[(Vector, Vector)](df3, standardScaler3, "standardized_features", "expected")(
+ assertResult)
}
test("sparse data and withMean") {
@@ -130,7 +129,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
.setWithMean(true)
.setWithStd(false)
.fit(df)
- assertResult(standardScaler.transform(df))
+ testTransformer[(Vector, Vector)](df, standardScaler, "standardized_features", "expected")(
+ assertResult)
}
test("StandardScaler read/write") {
@@ -149,4 +149,5 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
assert(newInstance.std === instance.std)
assert(newInstance.mean === instance.mean)
}
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index 5262b146b184e..20972d1f403b9 100755
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -17,28 +17,20 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Dataset, Row}
-
-object StopWordsRemoverSuite extends SparkFunSuite {
- def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
- t.transform(dataset)
- .select("filtered", "expected")
- .collect()
- .foreach { case Row(tokens, wantedTokens) =>
- assert(tokens === wantedTokens)
- }
- }
-}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.{DataFrame, Row}
-class StopWordsRemoverSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest {
- import StopWordsRemoverSuite._
import testImplicits._
+ def testStopWordsRemover(t: StopWordsRemover, dataFrame: DataFrame): Unit = {
+ testTransformer[(Array[String], Array[String])](dataFrame, t, "filtered", "expected") {
+ case Row(tokens: Seq[_], wantedTokens: Seq[_]) =>
+ assert(tokens === wantedTokens)
+ }
+ }
+
test("StopWordsRemover default") {
val remover = new StopWordsRemover()
.setInputCol("raw")
@@ -73,6 +65,57 @@ class StopWordsRemoverSuite
testStopWordsRemover(remover, dataSet)
}
+ test("StopWordsRemover with localed input (case insensitive)") {
+ val stopWords = Array("milk", "cookie")
+ val remover = new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol("filtered")
+ .setStopWords(stopWords)
+ .setCaseSensitive(false)
+ .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's.
+ val dataSet = Seq(
+ // scalastyle:off
+ (Seq("mİlk", "and", "nuts"), Seq("and", "nuts")),
+ // scalastyle:on
+ (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")),
+ (Seq(null), Seq(null)),
+ (Seq(), Seq())
+ ).toDF("raw", "expected")
+
+ testStopWordsRemover(remover, dataSet)
+ }
+
+ test("StopWordsRemover with localed input (case sensitive)") {
+ val stopWords = Array("milk", "cookie")
+ val remover = new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol("filtered")
+ .setStopWords(stopWords)
+ .setCaseSensitive(true)
+ .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's.
+ val dataSet = Seq(
+ // scalastyle:off
+ (Seq("mİlk", "and", "nuts"), Seq("mİlk", "and", "nuts")),
+ // scalastyle:on
+ (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")),
+ (Seq(null), Seq(null)),
+ (Seq(), Seq())
+ ).toDF("raw", "expected")
+
+ testStopWordsRemover(remover, dataSet)
+ }
+
+ test("StopWordsRemover with invalid locale") {
+ intercept[IllegalArgumentException] {
+ val stopWords = Array("test", "a", "an", "the")
+ new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol("filtered")
+ .setStopWords(stopWords)
+ .setLocale("rt") // invalid locale
+ }
+ }
+
test("StopWordsRemover case sensitive") {
val remover = new StopWordsRemover()
.setInputCol("raw")
@@ -151,9 +194,10 @@ class StopWordsRemoverSuite
.setOutputCol(outputCol)
val dataSet = Seq((Seq("The", "the", "swift"), Seq("swift"))).toDF("raw", outputCol)
- val thrown = intercept[IllegalArgumentException] {
- testStopWordsRemover(remover, dataSet)
- }
- assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.")
+ testTransformerByInterceptingException[(Array[String], Array[String])](
+ dataSet,
+ remover,
+ s"requirement failed: Column $outputCol already exists.",
+ "expected")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 775a04d3df050..df24367177011 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,17 +17,14 @@
package org.apache.spark.ml.feature
-import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
-class StringIndexerSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -46,19 +43,23 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
val indexerModel = indexer.fit(df)
-
MLTestingUtils.checkCopyAndUids(indexer, indexerModel)
-
- val transformed = indexerModel.transform(df)
- val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
- .asInstanceOf[NominalAttribute]
- assert(attr.values.get === Array("a", "c", "b"))
- val output = transformed.select("id", "labelIndex").rdd.map { r =>
- (r.getInt(0), r.getDouble(1))
- }.collect().toSet
// a -> 0, b -> 2, c -> 1
- val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
- assert(output === expected)
+ val expected = Seq(
+ (0, 0.0),
+ (1, 2.0),
+ (2, 1.0),
+ (3, 0.0),
+ (4, 0.0),
+ (5, 1.0)
+ ).toDF("id", "labelIndex")
+
+ testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows =>
+ val attr = Attribute.fromStructField(rows.head.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attr.values.get === Array("a", "c", "b"))
+ assert(rows.seq === expected.collect().toSeq)
+ }
}
test("StringIndexerUnseen") {
@@ -70,36 +71,38 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
+
// Verify we throw by default with unseen values
- intercept[SparkException] {
- indexer.transform(df2).collect()
- }
+ testTransformerByInterceptingException[(Int, String)](
+ df2,
+ indexer,
+ "Unseen label:",
+ "labelIndex")
- indexer.setHandleInvalid("skip")
// Verify that we skip the c record
- val transformedSkip = indexer.transform(df2)
- val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex"))
- .asInstanceOf[NominalAttribute]
- assert(attrSkip.values.get === Array("b", "a"))
- val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
- (r.getInt(0), r.getDouble(1))
- }.collect().toSet
// a -> 1, b -> 0
- val expectedSkip = Set((0, 1.0), (1, 0.0))
- assert(outputSkip === expectedSkip)
+ indexer.setHandleInvalid("skip")
+
+ val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF()
+ testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows =>
+ val attrSkip = Attribute.fromStructField(rows.head.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrSkip.values.get === Array("b", "a"))
+ assert(rows.seq === expectedSkip.collect().toSeq)
+ }
indexer.setHandleInvalid("keep")
- // Verify that we keep the unseen records
- val transformedKeep = indexer.transform(df2)
- val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex"))
- .asInstanceOf[NominalAttribute]
- assert(attrKeep.values.get === Array("b", "a", "__unknown"))
- val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
- (r.getInt(0), r.getDouble(1))
- }.collect().toSet
+
// a -> 1, b -> 0, c -> 2, d -> 3
- val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0))
- assert(outputKeep === expectedKeep)
+ val expectedKeep = Seq((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)).toDF()
+
+ // Verify that we keep the unseen records
+ testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows =>
+ val attrKeep = Attribute.fromStructField(rows.head.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+ assert(rows === expectedKeep.collect().toSeq)
+ }
}
test("StringIndexer with a numeric input column") {
@@ -109,16 +112,14 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
- val transformed = indexer.transform(df)
- val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
- .asInstanceOf[NominalAttribute]
- assert(attr.values.get === Array("100", "300", "200"))
- val output = transformed.select("id", "labelIndex").rdd.map { r =>
- (r.getInt(0), r.getDouble(1))
- }.collect().toSet
// 100 -> 0, 200 -> 2, 300 -> 1
- val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
- assert(output === expected)
+ val expected = Seq((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)).toDF()
+ testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "id", "labelIndex") { rows =>
+ val attr = Attribute.fromStructField(rows.head.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attr.values.get === Array("100", "300", "200"))
+ assert(rows === expected.collect().toSeq)
+ }
}
test("StringIndexer with NULLs") {
@@ -133,37 +134,36 @@ class StringIndexerSuite
withClue("StringIndexer should throw error when setHandleInvalid=error " +
"when given NULL values") {
- intercept[SparkException] {
- indexer.setHandleInvalid("error")
- indexer.fit(df).transform(df2).collect()
- }
+ indexer.setHandleInvalid("error")
+ testTransformerByInterceptingException[(Int, String)](
+ df2,
+ indexer.fit(df),
+ "StringIndexer encountered NULL value.",
+ "labelIndex")
}
indexer.setHandleInvalid("skip")
- val transformedSkip = indexer.fit(df).transform(df2)
- val attrSkip = Attribute
- .fromStructField(transformedSkip.schema("labelIndex"))
- .asInstanceOf[NominalAttribute]
- assert(attrSkip.values.get === Array("b", "a"))
- val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
- (r.getInt(0), r.getDouble(1))
- }.collect().toSet
+ val modelSkip = indexer.fit(df)
// a -> 1, b -> 0
- val expectedSkip = Set((0, 1.0), (1, 0.0))
- assert(outputSkip === expectedSkip)
+ val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF()
+ testTransformerByGlobalCheckFunc[(Int, String)](df2, modelSkip, "id", "labelIndex") { rows =>
+ val attrSkip =
+ Attribute.fromStructField(rows.head.schema("labelIndex")).asInstanceOf[NominalAttribute]
+ assert(attrSkip.values.get === Array("b", "a"))
+ assert(rows === expectedSkip.collect().toSeq)
+ }
indexer.setHandleInvalid("keep")
- val transformedKeep = indexer.fit(df).transform(df2)
- val attrKeep = Attribute
- .fromStructField(transformedKeep.schema("labelIndex"))
- .asInstanceOf[NominalAttribute]
- assert(attrKeep.values.get === Array("b", "a", "__unknown"))
- val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
- (r.getInt(0), r.getDouble(1))
- }.collect().toSet
// a -> 1, b -> 0, null -> 2
- val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0))
- assert(outputKeep === expectedKeep)
+ val expectedKeep = Seq((0, 1.0), (1, 0.0), (3, 2.0)).toDF()
+ val modelKeep = indexer.fit(df)
+ testTransformerByGlobalCheckFunc[(Int, String)](df2, modelKeep, "id", "labelIndex") { rows =>
+ val attrKeep = Attribute
+ .fromStructField(rows.head.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+ assert(rows === expectedKeep.collect().toSeq)
+ }
}
test("StringIndexerModel should keep silent if the input column does not exist.") {
@@ -171,7 +171,9 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
val df = spark.range(0L, 10L).toDF()
- assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
+ testTransformerByGlobalCheckFunc[Long](df, indexerModel, "id") { rows =>
+ assert(rows.toSet === df.collect().toSet)
+ }
}
test("StringIndexerModel can't overwrite output column") {
@@ -188,9 +190,12 @@ class StringIndexerSuite
.setOutputCol("indexedInput")
.fit(df)
- intercept[IllegalArgumentException] {
- indexer.setOutputCol("output").transform(df)
- }
+ testTransformerByInterceptingException[(Int, String)](
+ df,
+ indexer.setOutputCol("output"),
+ "Output column output already exists.",
+ "labelIndex")
+
}
test("StringIndexer read/write") {
@@ -223,7 +228,8 @@ class StringIndexerSuite
.setInputCol("index")
.setOutputCol("actual")
.setLabels(labels)
- idxToStr0.transform(df0).select("actual", "expected").collect().foreach {
+
+ testTransformer[(Int, String)](df0, idxToStr0, "actual", "expected") {
case Row(actual, expected) =>
assert(actual === expected)
}
@@ -234,7 +240,8 @@ class StringIndexerSuite
val idxToStr1 = new IndexToString()
.setInputCol("indexWithAttr")
.setOutputCol("actual")
- idxToStr1.transform(df1).select("actual", "expected").collect().foreach {
+
+ testTransformer[(Int, String)](df1, idxToStr1, "actual", "expected") {
case Row(actual, expected) =>
assert(actual === expected)
}
@@ -252,9 +259,10 @@ class StringIndexerSuite
.setInputCol("labelIndex")
.setOutputCol("sameLabel")
.setLabels(indexer.labels)
- idx2str.transform(transformed).select("label", "sameLabel").collect().foreach {
- case Row(a: String, b: String) =>
- assert(a === b)
+
+ testTransformer[(Int, String, Double)](transformed, idx2str, "sameLabel", "label") {
+ case Row(sameLabel, label) =>
+ assert(sameLabel === label)
}
}
@@ -286,10 +294,11 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
- val transformed = indexer.transform(df)
- val attrs =
- NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true)
- assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
+ testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "labelIndex") { rows =>
+ val attrs =
+ NominalAttribute.decodeStructField(rows.head.schema("labelIndex"), preserveName = true)
+ assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
+ }
}
test("StringIndexer order types") {
@@ -299,18 +308,17 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
- val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)),
- Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)),
- Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)),
- Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0)))
+ val expected = Seq(Seq((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)),
+ Seq((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)),
+ Seq((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)),
+ Seq((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0)))
var idx = 0
for (orderType <- StringIndexer.supportedStringOrderType) {
- val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df)
- val output = transformed.select("id", "labelIndex").rdd.map { r =>
- (r.getInt(0), r.getDouble(1))
- }.collect().toSet
- assert(output === expected(idx))
+ val model = indexer.setStringOrderType(orderType).fit(df)
+ testTransformerByGlobalCheckFunc[(Int, String)](df, model, "id", "labelIndex") { rows =>
+ assert(rows === expected(idx).toDF().collect().toSeq)
+ }
idx += 1
}
}
@@ -328,7 +336,11 @@ class StringIndexerSuite
.setOutputCol("CITYIndexed")
.fit(dfNoBristol)
- val dfWithIndex = model.transform(dfNoBristol)
- assert(dfWithIndex.filter($"CITYIndexed" === 1.0).count == 1)
+ testTransformerByGlobalCheckFunc[(String, String, String)](
+ dfNoBristol,
+ model,
+ "CITYIndexed") { rows =>
+ assert(rows.toList.count(_.getDouble(0) == 1.0) === 1)
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index c895659a2d8be..be59b0af2c78e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -19,16 +19,14 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
-class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class TokenizerSuite extends MLTest with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new Tokenizer)
@@ -42,12 +40,17 @@ class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
}
-class RegexTokenizerSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class RegexTokenizerSuite extends MLTest with DefaultReadWriteTest {
- import org.apache.spark.ml.feature.RegexTokenizerSuite._
import testImplicits._
+ def testRegexTokenizer(t: RegexTokenizer, dataframe: DataFrame): Unit = {
+ testTransformer[(String, Seq[String])](dataframe, t, "tokens", "wantedTokens") {
+ case Row(tokens, wantedTokens) =>
+ assert(tokens === wantedTokens)
+ }
+ }
+
test("params") {
ParamsSuite.checkParams(new RegexTokenizer)
}
@@ -105,14 +108,3 @@ class RegexTokenizerSuite
}
}
-object RegexTokenizerSuite extends SparkFunSuite {
-
- def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = {
- t.transform(dataset)
- .select("tokens", "wantedTokens")
- .collect()
- .foreach { case Row(tokens, wantedTokens) =>
- assert(tokens === wantedTokens)
- }
- }
-}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index eca065f7e775d..91fb24a268b8c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.ml.feature
import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
class VectorAssemblerSuite
@@ -31,30 +31,49 @@ class VectorAssemblerSuite
import testImplicits._
+ @transient var dfWithNullsAndNaNs: Dataset[_] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val sv = Vectors.sparse(2, Array(1), Array(3.0))
+ dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)](
+ (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+ (2, 1, 0.0, null, "a", sv, 6L, null),
+ (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null),
+ (4, 4, null, null, "a", sv, 9L, null),
+ (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null),
+ (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null))
+ .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls")
+ }
+
test("params") {
ParamsSuite.checkParams(new VectorAssembler)
}
test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
- assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
- assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
+ assert(assemble(Array(1), keepInvalid = true)(0.0)
+ === Vectors.sparse(1, Array.empty, Array.empty))
+ assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0)
+ === Vectors.sparse(2, Array(1), Array(1.0)))
val dv = Vectors.dense(2.0, 0.0)
- assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
+ assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) ===
+ Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
- assert(assemble(0.0, dv, 1.0, sv) ===
+ assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) ===
Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
- for (v <- Seq(1, "a", null)) {
- intercept[SparkException](assemble(v))
- intercept[SparkException](assemble(1.0, v))
+ for (v <- Seq(1, "a")) {
+ intercept[SparkException](assemble(Array(1), keepInvalid = true)(v))
+ intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v))
}
}
test("assemble should compress vectors") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
- val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
+ val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
assert(v1.isInstanceOf[SparseVector])
- val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
+ val sv = Vectors.sparse(1, Array(0), Array(4.0))
+ val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv)
assert(v2.isInstanceOf[DenseVector])
}
@@ -147,4 +166,94 @@ class VectorAssemblerSuite
.filter(vectorUDF($"features") > 1)
.count() == 1)
}
+
+ test("assemble should keep nulls when keepInvalid is true") {
+ import org.apache.spark.ml.feature.VectorAssembler.assemble
+ assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN))
+ assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null)
+ === Vectors.dense(1.0, Double.NaN, Double.NaN))
+ assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN))
+ assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN))
+ }
+
+ test("assemble should throw errors when keepInvalid is false") {
+ import org.apache.spark.ml.feature.VectorAssembler.assemble
+ intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null))
+ intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null))
+ intercept[SparkException](assemble(Array(1), keepInvalid = false)(null))
+ intercept[SparkException](assemble(Array(2), keepInvalid = false)(null))
+ }
+
+ test("get lengths functions") {
+ import org.apache.spark.ml.feature.VectorAssembler._
+ val df = dfWithNullsAndNaNs
+ assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2))
+ assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y")))
+ .getMessage.contains("VectorSizeHint"))
+ assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"),
+ Seq("y"))).getMessage.contains("VectorSizeHint"))
+
+ assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2))
+ assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID))
+ .getMessage.contains("VectorSizeHint"))
+ assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID))
+ .getMessage.contains("VectorSizeHint"))
+ }
+
+ test("Handle Invalid should behave properly") {
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("x", "y", "z", "n"))
+ .setOutputCol("features")
+
+ def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = {
+ val attributeY = new AttributeGroup("y", 2)
+ val attributeZ = new AttributeGroup(
+ "z",
+ Array[Attribute](
+ NumericAttribute.defaultAttr.withName("foo"),
+ NumericAttribute.defaultAttr.withName("bar")))
+ val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata())
+ .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter)
+ val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata)
+ output.collect()
+ output
+ }
+
+ def runWithFirstRow(mode: String): Dataset[_] = {
+ val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs)
+ output.collect()
+ output
+ }
+
+ def runWithAllNullVectors(mode: String): Dataset[_] = {
+ val output = assembler.setHandleInvalid(mode)
+ .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2"))
+ output.collect()
+ output
+ }
+
+ // behavior when vector size hint is given
+ assert(runWithMetadata("keep").count() == 6, "should keep all rows")
+ assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls")
+ // should throw error with nulls
+ intercept[SparkException](runWithMetadata("error"))
+ // should throw error with NaNs
+ intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4"))
+
+ // behavior when first row has information
+ assert(intercept[RuntimeException](runWithFirstRow("keep").count())
+ .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+ assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls")
+ intercept[SparkException](runWithFirstRow("error"))
+
+ // behavior when vector column is all null
+ assert(intercept[RuntimeException](runWithAllNullVectors("skip"))
+ .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+ assert(intercept[NullPointerException](runWithAllNullVectors("error"))
+ .getMessage.contains("VectorSizeHint"), "should suggest to use metadata")
+
+ // behavior when scalar column is all null
+ assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4)
+ }
+
}
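
Illustrative sketch (not part of the patch): a user-facing view of the VectorAssembler handleInvalid behavior that the new tests above exercise, assuming the "error" / "skip" / "keep" values used in those tests. The object name, column names, and data are made up.

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

object VectorAssemblerHandleInvalidSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // One clean row and one row with a null scalar input.
    val df = Seq[(java.lang.Double, java.lang.Double)](
      (1.0, 2.0),
      (3.0, null)
    ).toDF("a", "b")

    val assembler = new VectorAssembler()
      .setInputCols(Array("a", "b"))
      .setOutputCol("features")

    // "keep": null scalars are assembled as NaN entries in the output vector.
    assembler.setHandleInvalid("keep").transform(df).show(false)
    // "skip": rows containing nulls are dropped before assembly.
    assembler.setHandleInvalid("skip").transform(df).show(false)
    // "error" (the default): transform throws a SparkException on the null row.

    spark.stop()
  }
}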
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 69a7b75e32eb7..e5675e31bbecf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -19,18 +19,16 @@ package org.apache.spark.ml.feature
import scala.beans.{BeanInfo, BeanProperty}
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
-class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
- with DefaultReadWriteTest with Logging {
+class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging {
import testImplicits._
import VectorIndexerSuite.FeatureData
@@ -128,18 +126,27 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
MLTestingUtils.checkCopyAndUids(vectorIndexer, model)
- model.transform(densePoints1) // should work
- model.transform(sparsePoints1) // should work
+ testTransformer[FeatureData](densePoints1, model, "indexed") { _ => }
+ testTransformer[FeatureData](sparsePoints1, model, "indexed") { _ => }
+
// If the data is local Dataset, it throws AssertionError directly.
- intercept[AssertionError] {
- model.transform(densePoints2).collect()
- logInfo("Did not throw error when fit, transform were called on vectors of different lengths")
+ withClue("Did not throw error when fit, transform were called on " +
+ "vectors of different lengths") {
+ testTransformerByInterceptingException[FeatureData](
+ densePoints2,
+ model,
+ "VectorIndexerModel expected vector of length 3 but found length 4",
+ "indexed")
}
// If the data is distributed Dataset, it throws SparkException
// which is the wrapper of AssertionError.
- intercept[SparkException] {
- model.transform(densePoints2.repartition(2)).collect()
- logInfo("Did not throw error when fit, transform were called on vectors of different lengths")
+ withClue("Did not throw error when fit, transform were called " +
+ "on vectors of different lengths") {
+ testTransformerByInterceptingException[FeatureData](
+ densePoints2.repartition(2),
+ model,
+ "VectorIndexerModel expected vector of length 3 but found length 4",
+ "indexed")
}
intercept[SparkException] {
vectorIndexer.fit(badPoints)
@@ -178,46 +185,48 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
val categoryMaps = model.categoryMaps
// Chose correct categorical features
assert(categoryMaps.keys.toSet === categoricalFeatures)
- val transformed = model.transform(data).select("indexed")
- val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0))
- val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed"))
- assert(featureAttrs.name === "indexed")
- assert(featureAttrs.attributes.get.length === model.numFeatures)
- categoricalFeatures.foreach { feature: Int =>
- val origValueSet = collectedData.map(_(feature)).toSet
- val targetValueIndexSet = Range(0, origValueSet.size).toSet
- val catMap = categoryMaps(feature)
- assert(catMap.keys.toSet === origValueSet) // Correct categories
- assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices
- if (origValueSet.contains(0.0)) {
- assert(catMap(0.0) === 0) // value 0 gets index 0
- }
- // Check transformed data
- assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet)
- // Check metadata
- val featureAttr = featureAttrs(feature)
- assert(featureAttr.index.get === feature)
- featureAttr match {
- case attr: BinaryAttribute =>
- assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
- case attr: NominalAttribute =>
- assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
- assert(attr.isOrdinal.get === false)
- case _ =>
- throw new RuntimeException(errMsg + s". Categorical feature $feature failed" +
- s" metadata check. Found feature attribute: $featureAttr.")
+ testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows =>
+ val transformed = rows.map { r => Tuple1(r.getAs[Vector](0)) }.toDF("indexed")
+ val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0))
+ val featureAttrs = AttributeGroup.fromStructField(rows.head.schema("indexed"))
+ assert(featureAttrs.name === "indexed")
+ assert(featureAttrs.attributes.get.length === model.numFeatures)
+ categoricalFeatures.foreach { feature: Int =>
+ val origValueSet = collectedData.map(_(feature)).toSet
+ val targetValueIndexSet = Range(0, origValueSet.size).toSet
+ val catMap = categoryMaps(feature)
+ assert(catMap.keys.toSet === origValueSet) // Correct categories
+ assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices
+ if (origValueSet.contains(0.0)) {
+ assert(catMap(0.0) === 0) // value 0 gets index 0
+ }
+ // Check transformed data
+ assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet)
+ // Check metadata
+ val featureAttr = featureAttrs(feature)
+ assert(featureAttr.index.get === feature)
+ featureAttr match {
+ case attr: BinaryAttribute =>
+ assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
+ case attr: NominalAttribute =>
+ assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString))
+ assert(attr.isOrdinal.get === false)
+ case _ =>
+ throw new RuntimeException(errMsg + s". Categorical feature $feature failed" +
+ s" metadata check. Found feature attribute: $featureAttr.")
+ }
}
- }
- // Check numerical feature metadata.
- Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature))
- .foreach { feature: Int =>
- val featureAttr = featureAttrs(feature)
- featureAttr match {
- case attr: NumericAttribute =>
- assert(featureAttr.index.get === feature)
- case _ =>
- throw new RuntimeException(errMsg + s". Numerical feature $feature failed" +
- s" metadata check. Found feature attribute: $featureAttr.")
+ // Check numerical feature metadata.
+ Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature))
+ .foreach { feature: Int =>
+ val featureAttr = featureAttrs(feature)
+ featureAttr match {
+ case attr: NumericAttribute =>
+ assert(featureAttr.index.get === feature)
+ case _ =>
+ throw new RuntimeException(errMsg + s". Numerical feature $feature failed" +
+ s" metadata check. Found feature attribute: $featureAttr.")
+ }
}
}
} catch {
@@ -236,25 +245,32 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
(sparsePoints1, sparsePoints1TestInvalid))) {
val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error")
val model = vectorIndexer.fit(points)
- intercept[SparkException] {
- model.transform(pointsTestInvalid).collect()
- }
+ testTransformerByInterceptingException[FeatureData](
+ pointsTestInvalid,
+ model,
+ "VectorIndexer encountered invalid value",
+ "indexed")
val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip")
val model1 = vectorIndexer1.fit(points)
- val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed")
- .collect().map(_(0))
- val transformed1 = model1.transform(points).select("indexed").collect().map(_(0))
- assert(transformed1 === invalidTransformed1)
-
+ val expected = Seq(
+ Vectors.dense(1.0, 2.0, 0.0),
+ Vectors.dense(0.0, 1.0, 2.0),
+ Vectors.dense(0.0, 0.0, 1.0),
+ Vectors.dense(1.0, 3.0, 2.0))
+ testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model1, "indexed") { rows =>
+ assert(rows.map(_(0)) == expected)
+ }
+ testTransformerByGlobalCheckFunc[FeatureData](points, model1, "indexed") { rows =>
+ assert(rows.map(_(0)) == expected)
+ }
val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep")
val model2 = vectorIndexer2.fit(points)
- val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed")
- .collect().map(_(0))
- assert(invalidTransformed2 === transformed1 ++ Array(
- Vectors.dense(2.0, 2.0, 0.0),
- Vectors.dense(0.0, 4.0, 2.0),
- Vectors.dense(1.0, 3.0, 3.0))
- )
+ testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model2, "indexed") { rows =>
+ assert(rows.map(_(0)) == expected ++ Array(
+ Vectors.dense(2.0, 2.0, 0.0),
+ Vectors.dense(0.0, 4.0, 2.0),
+ Vectors.dense(1.0, 3.0, 3.0)))
+ }
}
}
@@ -263,12 +279,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
val points = data.collect().map(_.getAs[Vector](0))
val vectorIndexer = getIndexer.setMaxCategories(maxCategories)
val model = vectorIndexer.fit(data)
- val indexedPoints =
- model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect()
- points.zip(indexedPoints).foreach {
- case (orig: SparseVector, indexed: SparseVector) =>
- assert(orig.indices.length == indexed.indices.length)
- case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen
+ testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows =>
+ points.zip(rows.map(_(0))).foreach {
+ case (orig: SparseVector, indexed: SparseVector) =>
+ assert(orig.indices.length == indexed.indices.length)
+ case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen
+ }
}
}
checkSparsity(sparsePoints1, maxCategories = 2)
@@ -286,17 +302,18 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
val vectorIndexer = getIndexer.setMaxCategories(2)
val model = vectorIndexer.fit(densePoints1WithMeta)
// Check that ML metadata are preserved.
- val indexedPoints = model.transform(densePoints1WithMeta)
- val transAttributes: Array[Attribute] =
- AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get
- featureAttributes.zip(transAttributes).foreach { case (orig, trans) =>
- assert(orig.name === trans.name)
- (orig, trans) match {
- case (orig: NumericAttribute, trans: NumericAttribute) =>
- assert(orig.max.nonEmpty && orig.max === trans.max)
- case _ =>
+ testTransformerByGlobalCheckFunc[FeatureData](densePoints1WithMeta, model, "indexed") { rows =>
+ val transAttributes: Array[Attribute] =
+ AttributeGroup.fromStructField(rows.head.schema("indexed")).attributes.get
+ featureAttributes.zip(transAttributes).foreach { case (orig, trans) =>
+ assert(orig.name === trans.name)
+ (orig, trans) match {
+ case (orig: NumericAttribute, trans: NumericAttribute) =>
+ assert(orig.max.nonEmpty && orig.max === trans.max)
+ case _ =>
// do nothing
// TODO: Once input features marked as categorical are handled correctly, check that here.
+ }
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
index f6c9a76599fae..d89d10b320d84 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
@@ -17,17 +17,15 @@
package org.apache.spark.ml.feature
-import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamTest
class VectorSizeHintSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+ extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -40,16 +38,23 @@ class VectorSizeHintSuite
val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue")
val noSizeTransformer = new VectorSizeHint().setInputCol("vector")
- intercept[NoSuchElementException] (noSizeTransformer.transform(data))
+ testTransformerByInterceptingException[(Vector, Int)](
+ data,
+ noSizeTransformer,
+ "Failed to find a default value for size",
+ "vector")
intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema))
val noInputColTransformer = new VectorSizeHint().setSize(2)
- intercept[NoSuchElementException] (noInputColTransformer.transform(data))
+ testTransformerByInterceptingException[(Vector, Int)](
+ data,
+ noInputColTransformer,
+ "Failed to find a default value for inputCol",
+ "vector")
intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema))
}
test("Adding size to column of vectors.") {
-
val size = 3
val vectorColName = "vector"
val denseVector = Vectors.dense(1, 2, 3)
@@ -66,12 +71,15 @@ class VectorSizeHintSuite
.setInputCol(vectorColName)
.setSize(size)
.setHandleInvalid(handleInvalid)
- val withSize = transformer.transform(dataFrame)
- assert(
- AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size,
- "Transformer did not add expected size data.")
- val numRows = withSize.collect().length
- assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.")
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataFrame, transformer, vectorColName) {
+ rows => {
+ assert(
+ AttributeGroup.fromStructField(rows.head.schema(vectorColName)).size == size,
+ "Transformer did not add expected size data.")
+ val numRows = rows.length
+ assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.")
+ }
+ }
}
}
@@ -93,14 +101,16 @@ class VectorSizeHintSuite
.setInputCol(vectorColName)
.setSize(size)
.setHandleInvalid(handleInvalid)
- val withSize = transformer.transform(dataFrameWithMetadata)
-
- val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName))
- assert(newGroup.size === size, "Column has incorrect size metadata.")
- assert(
- newGroup.attributes.get === group.attributes.get,
- "VectorSizeHint did not preserve attributes.")
- withSize.collect
+ testTransformerByGlobalCheckFunc[(Int, Int, Int, Vector)](
+ dataFrameWithMetadata,
+ transformer,
+ vectorColName) { rows =>
+ val newGroup = AttributeGroup.fromStructField(rows.head.schema(vectorColName))
+ assert(newGroup.size === size, "Column has incorrect size metadata.")
+ assert(
+ newGroup.attributes.get === group.attributes.get,
+ "VectorSizeHint did not preserve attributes.")
+ }
}
}
@@ -120,7 +130,11 @@ class VectorSizeHintSuite
.setInputCol(vectorColName)
.setSize(size)
.setHandleInvalid(handleInvalid)
- intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata))
+ testTransformerByInterceptingException[(Int, Int, Int, Vector)](
+ dataFrameWithMetadata,
+ transformer,
+ "Trying to set size of vectors in `vector` to 4 but size already set to 3.",
+ vectorColName)
}
}
@@ -136,18 +150,36 @@ class VectorSizeHintSuite
.setHandleInvalid("error")
.setSize(3)
- intercept[SparkException](sizeHint.transform(dataWithNull).collect())
- intercept[SparkException](sizeHint.transform(dataWithShort).collect())
+ testTransformerByInterceptingException[Tuple1[Vector]](
+ dataWithNull,
+ sizeHint,
+ "Got null vector in VectorSizeHint",
+ "vector")
+
+ testTransformerByInterceptingException[Tuple1[Vector]](
+ dataWithShort,
+ sizeHint,
+ "VectorSizeHint Expecting a vector of size 3 but got 1",
+ "vector")
sizeHint.setHandleInvalid("skip")
- assert(sizeHint.transform(dataWithNull).count() === 1)
- assert(sizeHint.transform(dataWithShort).count() === 1)
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows =>
+ assert(rows.length === 1)
+ }
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows =>
+ assert(rows.length === 1)
+ }
sizeHint.setHandleInvalid("optimistic")
- assert(sizeHint.transform(dataWithNull).count() === 2)
- assert(sizeHint.transform(dataWithShort).count() === 2)
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows =>
+ assert(rows.length === 2)
+ }
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows =>
+ assert(rows.length === 2)
+ }
}
+
test("read/write") {
val sizeHint = new VectorSizeHint()
.setInputCol("myInputCol")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 1746ce53107c4..3d90f9d9ac764 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -17,16 +17,16 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructField, StructType}
-class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class VectorSlicerSuite extends MLTest with DefaultReadWriteTest {
+
+ import testImplicits._
test("params") {
val slicer = new VectorSlicer().setInputCol("feature")
@@ -84,12 +84,12 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
- def validateResults(df: DataFrame): Unit = {
- df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) =>
+ def validateResults(rows: Seq[Row]): Unit = {
+ rows.foreach { case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 === vec2)
}
- val resultMetadata = AttributeGroup.fromStructField(df.schema("result"))
- val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected"))
+ val resultMetadata = AttributeGroup.fromStructField(rows.head.schema("result"))
+ val expectedMetadata = AttributeGroup.fromStructField(rows.head.schema("expected"))
assert(resultMetadata.numAttributes === expectedMetadata.numAttributes)
resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) =>
assert(a === b)
@@ -97,13 +97,16 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
}
vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty)
- validateResults(vectorSlicer.transform(df))
+ testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")(
+ validateResults)
vectorSlicer.setIndices(Array(1)).setNames(Array("f4"))
- validateResults(vectorSlicer.transform(df))
+ testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")(
+ validateResults)
vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
- validateResults(vectorSlicer.transform(df))
+ testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")(
+ validateResults)
}
test("read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 10682ba176aca..b59c4e7967338 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -17,17 +17,17 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.util.Utils
-class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class Word2VecSuite extends MLTest with DefaultReadWriteTest {
+
+ import testImplicits._
test("params") {
ParamsSuite.checkParams(new Word2Vec)
@@ -36,10 +36,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
test("Word2Vec") {
-
- val spark = this.spark
- import spark.implicits._
-
val sentence = "a b " * 100 + "a c " * 10
val numOfWords = sentence.split(" ").size
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@@ -70,17 +66,13 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
// These expectations are just magic values, characterizing the current
// behavior. The test needs to be updated to be more general, see SPARK-11502
val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167)
- model.transform(docDF).select("result", "expected").collect().foreach {
+ testTransformer[(Seq[String], Vector)](docDF, model, "result", "expected") {
case Row(vector1: Vector, vector2: Vector) =>
assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.")
}
}
test("getVectors") {
-
- val spark = this.spark
- import spark.implicits._
-
val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@@ -119,9 +111,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("findSynonyms") {
- val spark = this.spark
- import spark.implicits._
-
val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
val docDF = doc.zip(doc).toDF("text", "alsotext")
@@ -154,9 +143,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("window size") {
- val spark = this.spark
- import spark.implicits._
-
val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
val docDF = doc.zip(doc).toDF("text", "alsotext")
@@ -227,8 +213,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
test("Word2Vec works with input that is non-nullable (NGram)") {
- val spark = this.spark
- import spark.implicits._
val sentence = "a q s t q s t b b b s t m s t m q "
val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text")
@@ -243,7 +227,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.fit(ngramDF)
// Just test that this transformation succeeds
- model.transform(ngramDF).collect()
+ testTransformerByGlobalCheckFunc[(Seq[String], Seq[String])](ngramDF, model, "result") { _ => }
}
}
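A minimal fit/transform sketch of the Word2Vec pattern these tests now route through the MLTest helpers, assuming an active SparkSession `spark` (the toy corpus is hypothetical):

    import org.apache.spark.ml.feature.Word2Vec

    val docs = spark.createDataFrame(Seq(
      Tuple1("a b c a b".split(" ").toSeq),
      Tuple1("a c b a".split(" ").toSeq))).toDF("text")

    val model = new Word2Vec()
      .setInputCol("text").setOutputCol("result")
      .setVectorSize(3).setMinCount(0).setSeed(42L)
      .fit(docs)

    // Each document is mapped to the average of its word vectors.
    model.transform(docs).select("result").show(false)
    model.findSynonyms("a", 2).show()   // nearest vocabulary words by cosine similarity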
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
new file mode 100644
index 0000000000000..2252151af306b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.fpm
+
+import org.apache.spark.ml.util.MLTest
+import org.apache.spark.sql.DataFrame
+
+class PrefixSpanSuite extends MLTest {
+
+ import testImplicits._
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ }
+
+ test("PrefixSpan projections with multiple partial starts") {
+ val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
+ val result = new PrefixSpan()
+ .setMinSupport(1.0)
+ .setMaxPatternLength(2)
+ .setMaxLocalProjDBSize(32000000)
+ .findFrequentSequentialPatterns(smallDataset)
+ .as[(Seq[Seq[Int]], Long)].collect()
+ val expected = Array(
+ (Seq(Seq(1)), 1L),
+ (Seq(Seq(1, 2)), 1L),
+ (Seq(Seq(1), Seq(1)), 1L),
+ (Seq(Seq(1), Seq(2)), 1L),
+ (Seq(Seq(1), Seq(3)), 1L),
+ (Seq(Seq(1, 3)), 1L),
+ (Seq(Seq(2)), 1L),
+ (Seq(Seq(2, 3)), 1L),
+ (Seq(Seq(2), Seq(1)), 1L),
+ (Seq(Seq(2), Seq(2)), 1L),
+ (Seq(Seq(2), Seq(3)), 1L),
+ (Seq(Seq(3)), 1L))
+ compareResults[Int](expected, result)
+ }
+
+ /*
+ To verify expected results for `smallTestData`, create file "prefixSpanSeqs2" with content
+ (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)):
+ 1 1 2 1 2
+ 1 2 1 3
+ 2 1 1 1
+ 2 2 2 3 2
+ 2 3 2 1 2
+ 3 1 2 1 2
+ 3 2 1 5
+ 4 1 1 6
+ In R, run:
+ library("arulesSequences")
+ prefixSpanSeqs = read_baskets("prefixSpanSeqs2", info = c("sequenceID","eventID","SIZE"))
+ freqItemSeq = cspade(prefixSpanSeqs,
+ parameter = list(support = 0.5, maxlen = 5))
+ resSeq = as(freqItemSeq, "data.frame")
+ resSeq
+
+ sequence support
+ 1 <{1}> 0.75
+ 2 <{2}> 0.75
+ 3 <{3}> 0.50
+ 4 <{1},{3}> 0.50
+ 5 <{1,2}> 0.75
+ */
+ val smallTestData = Seq(
+ Seq(Seq(1, 2), Seq(3)),
+ Seq(Seq(1), Seq(3, 2), Seq(1, 2)),
+ Seq(Seq(1, 2), Seq(5)),
+ Seq(Seq(6)))
+
+ val smallTestDataExpectedResult = Array(
+ (Seq(Seq(1)), 3L),
+ (Seq(Seq(2)), 3L),
+ (Seq(Seq(3)), 2L),
+ (Seq(Seq(1), Seq(3)), 2L),
+ (Seq(Seq(1, 2)), 3L)
+ )
+
+ test("PrefixSpan Integer type, variable-size itemsets") {
+ val df = smallTestData.toDF("sequence")
+ val result = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ .setMaxLocalProjDBSize(32000000)
+ .findFrequentSequentialPatterns(df)
+ .as[(Seq[Seq[Int]], Long)].collect()
+
+ compareResults[Int](smallTestDataExpectedResult, result)
+ }
+
+ test("PrefixSpan input row with nulls") {
+ val df = (smallTestData :+ null).toDF("sequence")
+ val result = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ .setMaxLocalProjDBSize(32000000)
+ .findFrequentSequentialPatterns(df)
+ .as[(Seq[Seq[Int]], Long)].collect()
+
+ compareResults[Int](smallTestDataExpectedResult, result)
+ }
+
+ test("PrefixSpan String type, variable-size itemsets") {
+ val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap
+ val df = smallTestData
+ .map(seq => seq.map(itemSet => itemSet.map(intToString)))
+ .toDF("sequence")
+ val result = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ .setMaxLocalProjDBSize(32000000)
+ .findFrequentSequentialPatterns(df)
+ .as[(Seq[Seq[String]], Long)].collect()
+
+ val expected = smallTestDataExpectedResult.map { case (seq, freq) =>
+ (seq.map(itemSet => itemSet.map(intToString)), freq)
+ }
+ compareResults[String](expected, result)
+ }
+
+ private def compareResults[Item](
+ expectedValue: Array[(Seq[Seq[Item]], Long)],
+ actualValue: Array[(Seq[Seq[Item]], Long)]): Unit = {
+ val expectedSet = expectedValue.map { x =>
+ (x._1.map(_.toSet), x._2)
+ }.toSet
+ val actualSet = actualValue.map { x =>
+ (x._1.map(_.toSet), x._2)
+ }.toSet
+ assert(expectedSet === actualSet)
+ }
+}
+
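As a quick sanity check of the minSupport = 0.5 threshold used with `smallTestData` above (4 input sequences, so a frequent pattern must occur in at least 2 of them), this standalone snippet recomputes the support of the single-itemset pattern <{1,2}> without Spark:

    val sequences = Seq(
      Seq(Seq(1, 2), Seq(3)),
      Seq(Seq(1), Seq(3, 2), Seq(1, 2)),
      Seq(Seq(1, 2), Seq(5)),
      Seq(Seq(6)))

    // A sequence supports <{1,2}> if any of its itemsets contains both items.
    val support = sequences.count(_.exists(itemset => Set(1, 2).subsetOf(itemset.toSet)))
    assert(support == 3 && support >= 0.5 * sequences.length)   // matches (Seq(Seq(1, 2)), 3L) above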
diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala
index a8833c615865d..527b3f8955968 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala
@@ -65,11 +65,71 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(count50 > 0 && count50 < countTotal)
}
+ test("readImages test: recursive = false") {
+ val df = readImages(imagePath, null, false, 3, true, 1.0, 0)
+ assert(df.count() === 0)
+ }
+
+ test("readImages test: read jpg image") {
+ val df = readImages(imagePath + "/kittens/DP153539.jpg", null, false, 3, true, 1.0, 0)
+ assert(df.count() === 1)
+ }
+
+ test("readImages test: read png image") {
+ val df = readImages(imagePath + "/multi-channel/BGRA.png", null, false, 3, true, 1.0, 0)
+ assert(df.count() === 1)
+ }
+
+ test("readImages test: read non image") {
+ val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, true, 1.0, 0)
+ assert(df.schema("image").dataType == columnSchema, "data do not fit ImageSchema")
+ assert(df.count() === 0)
+ }
+
+ test("readImages test: read non image and dropImageFailures is false") {
+ val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, false, 1.0, 0)
+ assert(df.count() === 1)
+ }
+
+ test("readImages test: sampleRatio > 1") {
+ val e = intercept[IllegalArgumentException] {
+ readImages(imagePath, null, true, 3, true, 1.1, 0)
+ }
+ assert(e.getMessage.contains("sampleRatio"))
+ }
+
+ test("readImages test: sampleRatio < 0") {
+ val e = intercept[IllegalArgumentException] {
+ readImages(imagePath, null, true, 3, true, -0.1, 0)
+ }
+ assert(e.getMessage.contains("sampleRatio"))
+ }
+
+ test("readImages test: sampleRatio = 0") {
+ val df = readImages(imagePath, null, true, 3, true, 0.0, 0)
+ assert(df.count() === 0)
+ }
+
+ test("readImages test: with sparkSession") {
+ val df = readImages(imagePath, sparkSession = spark, true, 3, true, 1.0, 0)
+ assert(df.count() === 8)
+ }
+
test("readImages partition test") {
val df = readImages(imagePath, null, true, 3, true, 1.0, 0)
assert(df.rdd.getNumPartitions === 3)
}
+ test("readImages partition test: < 0") {
+ val df = readImages(imagePath, null, true, -3, true, 1.0, 0)
+ assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism)
+ }
+
+ test("readImages partition test: = 0") {
+ val df = readImages(imagePath, null, true, 0, true, 1.0, 0)
+ assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism)
+ }
+
// Images with the different number of channels
test("readImages pixel values test") {
@@ -93,7 +153,7 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext {
// - default representation for 3-channel RGB images is BGR row-wise:
// (B00, G00, R00, B10, G10, R10, ...)
// - default representation for 4-channel RGB images is BGRA row-wise:
- // (B00, G00, R00, A00, B10, G10, R10, A00, ...)
+ // (B00, G00, R00, A00, B10, G10, R10, A10, ...)
private val firstBytes20 = Map(
"grayscale.jpg" ->
(("CV_8UC1", Array[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62,
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index addcd21d50aac..e3dfe2faf5698 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -22,8 +22,7 @@ import java.util.Random
import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.WrappedArray
+import scala.collection.mutable.{ArrayBuffer, WrappedArray}
import scala.language.existentials
import com.github.fommil.netlib.BLAS.{getInstance => blas}
@@ -35,21 +34,20 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.recommendation.ALS._
-import org.apache.spark.ml.recommendation.ALS.Rating
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.recommendation.MatrixFactorizationModelSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
-import org.apache.spark.sql.{DataFrame, Row, SparkSession}
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.{DataFrame, Encoder, Row, SparkSession}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.streaming.StreamingQueryException
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class ALSSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
+class ALSSuite extends MLTest with DefaultReadWriteTest with Logging {
override def beforeAll(): Unit = {
super.beforeAll()
@@ -413,34 +411,36 @@ class ALSSuite
.setSeed(0)
val alpha = als.getAlpha
val model = als.fit(training.toDF())
- val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map {
- case Row(rating: Float, prediction: Float) =>
- (rating.toDouble, prediction.toDouble)
+ testTransformerByGlobalCheckFunc[Rating[Int]](test.toDF(), model, "rating", "prediction") {
+ case rows: Seq[Row] =>
+ val predictions = rows.map(row => (row.getFloat(0).toDouble, row.getFloat(1).toDouble))
+
+ val rmse =
+ if (implicitPrefs) {
+ // TODO: Use a better (rank-based?) evaluation metric for implicit feedback.
+ // We limit the ratings and the predictions to interval [0, 1] and compute the
+ // weighted RMSE with the confidence scores as weights.
+ val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) =>
+ val confidence = 1.0 + alpha * math.abs(rating)
+ val rating01 = math.max(math.min(rating, 1.0), 0.0)
+ val prediction01 = math.max(math.min(prediction, 1.0), 0.0)
+ val err = prediction01 - rating01
+ (confidence, confidence * err * err)
+ }.reduce[(Double, Double)] { case ((c0, e0), (c1, e1)) =>
+ (c0 + c1, e0 + e1)
+ }
+ math.sqrt(weightedSumSq / totalWeight)
+ } else {
+ val errorSquares = predictions.map { case (rating, prediction) =>
+ val err = rating - prediction
+ err * err
+ }
+ val mse = errorSquares.sum / errorSquares.length
+ math.sqrt(mse)
+ }
+ logInfo(s"Test RMSE is $rmse.")
+ assert(rmse < targetRMSE)
}
- val rmse =
- if (implicitPrefs) {
- // TODO: Use a better (rank-based?) evaluation metric for implicit feedback.
- // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE
- // with the confidence scores as weights.
- val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) =>
- val confidence = 1.0 + alpha * math.abs(rating)
- val rating01 = math.max(math.min(rating, 1.0), 0.0)
- val prediction01 = math.max(math.min(prediction, 1.0), 0.0)
- val err = prediction01 - rating01
- (confidence, confidence * err * err)
- }.reduce { case ((c0, e0), (c1, e1)) =>
- (c0 + c1, e0 + e1)
- }
- math.sqrt(weightedSumSq / totalWeight)
- } else {
- val mse = predictions.map { case (rating, prediction) =>
- val err = rating - prediction
- err * err
- }.mean()
- math.sqrt(mse)
- }
- logInfo(s"Test RMSE is $rmse.")
- assert(rmse < targetRMSE)
MLTestingUtils.checkCopyAndUids(als, model)
}
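The implicit-preference branch above computes a confidence-weighted RMSE; a standalone sketch of the same arithmetic on hypothetical (rating, prediction) pairs with alpha = 1.0:

    val alpha = 1.0
    val pairs = Seq((4.0, 0.9), (0.0, 0.2), (2.0, 0.7))   // hypothetical (rating, prediction)

    val (totalWeight, weightedSumSq) = pairs.map { case (rating, prediction) =>
      val confidence = 1.0 + alpha * math.abs(rating)
      // clamp both values into [0, 1] before taking the error
      val err = math.min(math.max(prediction, 0.0), 1.0) - math.min(math.max(rating, 0.0), 1.0)
      (confidence, confidence * err * err)
    }.reduce[(Double, Double)] { case ((c0, e0), (c1, e1)) => (c0 + c1, e0 + e1) }

    val weightedRmse = math.sqrt(weightedSumSq / totalWeight)   // sqrt(sum(c_i * e_i^2) / sum(c_i))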
@@ -586,6 +586,68 @@ class ALSSuite
allModelParamSettings, checkModelData)
}
+ private def checkNumericTypesALS(
+ estimator: ALS,
+ spark: SparkSession,
+ column: String,
+ baseType: NumericType)
+ (check: (ALSModel, ALSModel) => Unit)
+ (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = {
+ val dfs = genRatingsDFWithNumericCols(spark, column)
+ val df = dfs.find {
+ case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType
+ } match {
+ case Some((_, df)) => df
+ }
+ val expected = estimator.fit(df)
+ val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2)))
+ actuals.foreach { case (_, actual) => check(expected, actual) }
+ actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) }
+
+ val baseDF = dfs.find(_._1.numericType == baseType).get._2
+ val others = baseDF.columns.toSeq.diff(Seq(column)).map(col)
+ val cols = Seq(col(column).cast(StringType)) ++ others
+ val strDF = baseDF.select(cols: _*)
+ val thrown = intercept[IllegalArgumentException] {
+ estimator.fit(strDF)
+ }
+ assert(thrown.getMessage.contains(
+ s"$column must be of type NumericType but was actually of type StringType"))
+ }
+
+ private class NumericTypeWithEncoder[A](val numericType: NumericType)
+ (implicit val encoder: Encoder[(A, Int, Double)])
+
+ private def genRatingsDFWithNumericCols(
+ spark: SparkSession,
+ column: String) = {
+
+ import testImplicits._
+
+ val df = spark.createDataFrame(Seq(
+ (0, 10, 1.0),
+ (1, 20, 2.0),
+ (2, 30, 3.0),
+ (3, 40, 4.0),
+ (4, 50, 5.0)
+ )).toDF("user", "item", "rating")
+
+ val others = df.columns.toSeq.diff(Seq(column)).map(col)
+ val types =
+ Seq(new NumericTypeWithEncoder[Short](ShortType),
+ new NumericTypeWithEncoder[Long](LongType),
+ new NumericTypeWithEncoder[Int](IntegerType),
+ new NumericTypeWithEncoder[Float](FloatType),
+ new NumericTypeWithEncoder[Byte](ByteType),
+ new NumericTypeWithEncoder[Double](DoubleType),
+ new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())
+ )
+ types.map { t =>
+ val cols = Seq(col(column).cast(t.numericType)) ++ others
+ t -> df.select(cols: _*)
+ }
+ }
+
test("input type validation") {
val spark = this.spark
import spark.implicits._
@@ -595,12 +657,16 @@ class ALSSuite
val als = new ALS().setMaxIter(1).setRank(1)
Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
case (colName, sqlType) =>
- MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) {
+ checkNumericTypesALS(als, spark, colName, sqlType) {
(ex, act) =>
- ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
- } { (ex, act, _) =>
- ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~==
- act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
+ ex.userFactors.first().getSeq[Float](1) === act.userFactors.first().getSeq[Float](1)
+ } { (ex, act, df, enc) =>
+ val expected = ex.transform(df).selectExpr("prediction")
+ .first().getFloat(0)
+ testTransformerByGlobalCheckFunc(df, act, "prediction") {
+ case rows: Seq[Row] =>
+ expected ~== rows.head.getFloat(0) absTol 1e-6
+ }(enc)
}
}
// check user/item ids falling outside of Int range
@@ -628,18 +694,22 @@ class ALSSuite
}
withClue("transform should fail when ids exceed integer range. ") {
val model = als.fit(df)
- assert(intercept[SparkException] {
- model.transform(df.select(df("user_big").as("user"), df("item"))).first
- }.getMessage.contains(msg))
- assert(intercept[SparkException] {
- model.transform(df.select(df("user_small").as("user"), df("item"))).first
- }.getMessage.contains(msg))
- assert(intercept[SparkException] {
- model.transform(df.select(df("item_big").as("item"), df("user"))).first
- }.getMessage.contains(msg))
- assert(intercept[SparkException] {
- model.transform(df.select(df("item_small").as("item"), df("user"))).first
- }.getMessage.contains(msg))
+ def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = {
+ assert(intercept[SparkException] {
+ model.transform(dataFrame).first
+ }.getMessage.contains(msg))
+ assert(intercept[StreamingQueryException] {
+ testTransformer[A](dataFrame, model, "prediction") { _ => }
+ }.getMessage.contains(msg))
+ }
+ testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"),
+ df("item")))
+ testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"),
+ df("item")))
+ testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"),
+ df("user")))
+ testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"),
+ df("user")))
}
}
@@ -662,28 +732,31 @@ class ALSSuite
val knownItem = data.select(max("item")).as[Int].first()
val unknownItem = knownItem + 20
val test = Seq(
- (unknownUser, unknownItem),
- (knownUser, unknownItem),
- (unknownUser, knownItem),
- (knownUser, knownItem)
- ).toDF("user", "item")
+ (unknownUser, unknownItem, true),
+ (knownUser, unknownItem, true),
+ (unknownUser, knownItem, true),
+ (knownUser, knownItem, false)
+ ).toDF("user", "item", "expectedIsNaN")
val als = new ALS().setMaxIter(1).setRank(1)
// default is 'nan'
val defaultModel = als.fit(data)
- val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect()
- assert(defaultPredictions.length == 4)
- assert(defaultPredictions.slice(0, 3).forall(_.isNaN))
- assert(!defaultPredictions.last.isNaN)
+ testTransformer[(Int, Int, Boolean)](test, defaultModel, "expectedIsNaN", "prediction") {
+ case Row(expectedIsNaN: Boolean, prediction: Float) =>
+ assert(prediction.isNaN === expectedIsNaN)
+ }
// check 'drop' strategy should filter out rows with unknown users/items
- val dropPredictions = defaultModel
- .setColdStartStrategy("drop")
- .transform(test)
- .select("prediction").as[Float].collect()
- assert(dropPredictions.length == 1)
- assert(!dropPredictions.head.isNaN)
- assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14)
+ val defaultPrediction = defaultModel.transform(test).select("prediction")
+ .as[Float].filter(!_.isNaN).first()
+ testTransformerByGlobalCheckFunc[(Int, Int, Boolean)](test,
+ defaultModel.setColdStartStrategy("drop"), "prediction") {
+ case rows: Seq[Row] =>
+ val dropPredictions = rows.map(_.getFloat(0))
+ assert(dropPredictions.length == 1)
+ assert(!dropPredictions.head.isNaN)
+ assert(dropPredictions.head ~== defaultPrediction relTol 1e-14)
+ }
}
test("case insensitive cold start param value") {
@@ -693,7 +766,7 @@ class ALSSuite
val data = ratings.toDF
val model = new ALS().fit(data)
Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s =>
- model.setColdStartStrategy(s).transform(data)
+ testTransformer[Rating[Int]](data, model.setColdStartStrategy(s), "prediction") { _ => }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 68a1218c23ece..9ae27339b11d5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -136,6 +136,21 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
assert(importances.toArray.forall(_ >= 0.0))
}
+ test("prediction on single instance") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(3)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+ val model = dt.fit(df)
+ testPredictionModelSinglePrediction(model, df)
+ }
+
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
@@ -176,6 +191,20 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}
+
+ test("label/impurity stats") {
+ val categoricalFeatures = Map(0 -> 2, 1 -> 2)
+ val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
+ val dtr = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(8)
+ val model = dtr.fit(df)
+ val statInfo = model.rootNode
+
+ assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0
+ && statInfo.getSumOfSquares == 600.0)
+ }
}
private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 11c593b521e65..b145c7a3dc952 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -20,13 +20,15 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.lit
import org.apache.spark.util.Utils
/**
@@ -99,6 +101,14 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("prediction on single instance") {
+ val gbt = new GBTRegressor()
+ .setMaxDepth(2)
+ .setMaxIter(2)
+ val model = gbt.fit(trainData.toDF())
+ testPredictionModelSinglePrediction(model, validationData.toDF)
+ }
+
test("Checkpointing") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
@@ -193,9 +203,81 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
assert(mostImportantFeature !== mostIF)
}
+ test("model evaluateEachIteration") {
+ for (lossType <- GBTRegressor.supportedLossTypes) {
+ val gbt = new GBTRegressor()
+ .setSeed(1L)
+ .setMaxDepth(2)
+ .setMaxIter(3)
+ .setLossType(lossType)
+ val model3 = gbt.fit(trainData.toDF)
+ val model1 = new GBTRegressionModel("gbt-reg-model-test1",
+ model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures)
+ val model2 = new GBTRegressionModel("gbt-reg-model-test2",
+ model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures)
+ for (evalLossType <- GBTRegressor.supportedLossTypes) {
+ val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType)
+ val lossErr1 = GradientBoostedTrees.computeError(validationData,
+ model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType))
+ val lossErr2 = GradientBoostedTrees.computeError(validationData,
+ model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType))
+ val lossErr3 = GradientBoostedTrees.computeError(validationData,
+ model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType))
- /////////////////////////////////////////////////////////////////////////////
+ assert(evalArr(0) ~== lossErr1 relTol 1E-3)
+ assert(evalArr(1) ~== lossErr2 relTol 1E-3)
+ assert(evalArr(2) ~== lossErr3 relTol 1E-3)
+ }
+ }
+ }
+
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ val validationIndicatorCol = "validationIndicator"
+ val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false))
+ val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true))
+
+ val numIter = 20
+ for (lossType <- GBTRegressor.supportedLossTypes) {
+ val gbt = new GBTRegressor()
+ .setSeed(123)
+ .setMaxDepth(2)
+ .setLossType(lossType)
+ .setMaxIter(numIter)
+ val modelWithoutValidation = gbt.fit(trainDF)
+
+ gbt.setValidationIndicatorCol(validationIndicatorCol)
+ val modelWithValidation = gbt.fit(trainDF.union(validationDF))
+
+ assert(modelWithoutValidation.numTrees === numIter)
+ // early stop
+ assert(modelWithValidation.numTrees < numIter)
+
+ val errorWithoutValidation = GradientBoostedTrees.computeError(validationData,
+ modelWithoutValidation.trees, modelWithoutValidation.treeWeights,
+ modelWithoutValidation.getOldLossType)
+ val errorWithValidation = GradientBoostedTrees.computeError(validationData,
+ modelWithValidation.trees, modelWithValidation.treeWeights,
+ modelWithValidation.getOldLossType)
+
+ assert(errorWithValidation < errorWithoutValidation)
+
+ val evaluationArray = GradientBoostedTrees
+ .evaluateEachIteration(validationData, modelWithoutValidation.trees,
+ modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
+ OldAlgo.Regression)
+ assert(evaluationArray.length === numIter)
+ assert(evaluationArray(modelWithValidation.numTrees) >
+ evaluationArray(modelWithValidation.numTrees - 1))
+ var i = 1
+ while (i < modelWithValidation.numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index ef2ff94a5e213..997c50157dcda 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -211,6 +211,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
assert(model.getLink === "identity")
}
+ test("prediction on single instance") {
+ val glr = new GeneralizedLinearRegression
+ val model = glr.setFamily("gaussian").setLink("identity")
+ .fit(datasetGaussianIdentity)
+
+ testPredictionModelSinglePrediction(model, datasetGaussianIdentity)
+ }
+
test("generalized linear regression: gaussian family against glm") {
/*
R code:
@@ -485,11 +493,20 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
}
[1] -0.0457441 -0.6833928
[1] 1.8121235 -0.1747493 -0.5815417
+
+ R code for deviance calculation:
+ data = cbind(y=c(0,1,0,0,0,1), x1=c(18, 12, 15, 13, 15, 16), x2=c(1,0,0,2,1,1))
+ summary(glm(y~x1+x2, family=poisson, data=data.frame(data)))$deviance
+ [1] 3.70055
+ summary(glm(y~x1+x2-1, family=poisson, data=data.frame(data)))$deviance
+ [1] 3.809296
*/
val expected = Seq(
Vectors.dense(0.0, -0.0457441, -0.6833928),
Vectors.dense(1.8121235, -0.1747493, -0.5815417))
+ val residualDeviancesR = Array(3.809296, 3.70055)
+
import GeneralizedLinearRegression._
var idx = 0
@@ -502,6 +519,7 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
s"$link link and fitIntercept = $fitIntercept (with zero values).")
+ assert(model.summary.deviance ~== residualDeviancesR(idx) absTol 1E-3)
idx += 1
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index d42cb1714478f..90ceb7dee38f7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -17,18 +17,23 @@
package org.apache.spark.ml.regression
+import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.util.Random
+import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel}
+
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.sql.{DataFrame, Row}
-class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
+
+class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest {
import testImplicits._
@@ -636,6 +641,13 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("prediction on single instance") {
+ val trainer = new LinearRegression
+ val model = trainer.fit(datasetWithDenseFeature)
+
+ testPredictionModelSinglePrediction(model, datasetWithDenseFeature)
+ }
+
test("linear regression model with constant label") {
/*
R code:
@@ -1045,6 +1057,24 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
LinearRegressionSuite.allParamSettings, checkModelData)
}
+ test("pmml export") {
+ val lr = new LinearRegression()
+ val model = lr.fit(datasetWithWeight)
+ def checkModel(pmml: PMML): Unit = {
+ val dd = pmml.getDataDictionary
+ assert(dd.getNumberOfFields === 3)
+ val fields = dd.getDataFields.asScala
+ assert(fields(0).getName().toString === "field_0")
+ assert(fields(0).getOpType() == OpType.CONTINUOUS)
+ val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+ val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+ val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList
+ assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+ assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+ }
+ testPMMLWrite(sc, model, checkModel)
+ }
+
test("should support all NumericType labels and weights, and not support other types") {
for (solver <- Seq("auto", "l-bfgs", "normal")) {
val lr = new LinearRegression().setMaxIter(1).setSolver(solver)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 8b8e8a655f47b..e83c49f932973 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,22 +19,22 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
/**
* Test suite for [[RandomForestRegressor]].
*/
-class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
- with DefaultReadWriteTest{
+class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest {
import RandomForestRegressorSuite.compareAPIs
+ import testImplicits._
private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
@@ -74,6 +74,20 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
regressionTestWithContinuousFeatures(rf)
}
+ test("prediction on single instance") {
+ val rf = new RandomForestRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setNumTrees(1)
+ .setFeatureSubsetStrategy("auto")
+ .setSeed(123)
+
+ val df = orderedLabeledPoints50_1000.toDF()
+ val model = rf.fit(df)
+ testPredictionModelSinglePrediction(model, df)
+ }
+
test("Feature importance with toy data") {
val rf = new RandomForestRegressor()
.setImpurity("variance")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala
new file mode 100644
index 0000000000000..1312de3a1b522
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat
+
+import org.apache.commons.math3.distribution.{ExponentialDistribution, NormalDistribution,
+ RealDistribution, UniformRealDistribution}
+import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => Math3KSTest}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
+
+class KolmogorovSmirnovTestSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ import testImplicits._
+
+ def apacheCommonMath3EquivalenceTest(
+ sampleDist: RealDistribution,
+ theoreticalDist: RealDistribution,
+ theoreticalDistByName: (String, Array[Double]),
+ rejectNullHypothesis: Boolean): Unit = {
+
+ // set seeds
+ val seed = 10L
+ sampleDist.reseedRandomGenerator(seed)
+ if (theoreticalDist != null) {
+ theoreticalDist.reseedRandomGenerator(seed)
+ }
+
+ // Sample data from the distributions and parallelize it
+ val n = 100000
+ val sampledArray = sampleDist.sample(n)
+ val sampledDF = sc.parallelize(sampledArray, 10).toDF("sample")
+
+ // Use a local Apache Commons Math KS test to verify calculations
+ val ksTest = new Math3KSTest()
+ val pThreshold = 0.05
+
+ // Comparing a standard normal sample to a standard normal distribution
+ val Row(pValue1: Double, statistic1: Double) =
+ if (theoreticalDist != null) {
+ val cdf = (x: Double) => theoreticalDist.cumulativeProbability(x)
+ KolmogorovSmirnovTest.test(sampledDF, "sample", cdf).head()
+ } else {
+ KolmogorovSmirnovTest.test(sampledDF, "sample",
+ theoreticalDistByName._1,
+ theoreticalDistByName._2: _*
+ ).head()
+ }
+ val theoreticalDistMath3 = if (theoreticalDist == null) {
+ assert(theoreticalDistByName._1 == "norm")
+ val params = theoreticalDistByName._2
+ new NormalDistribution(params(0), params(1))
+ } else {
+ theoreticalDist
+ }
+ val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(theoreticalDistMath3, sampledArray)
+ val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n)
+ // Verify vs the Apache Commons Math KS test
+ assert(statistic1 ~== referenceStat1 relTol 1e-4)
+ assert(pValue1 ~== referencePVal1 relTol 1e-4)
+
+ if (rejectNullHypothesis) {
+ assert(pValue1 < pThreshold)
+ } else {
+ assert(pValue1 > pThreshold)
+ }
+ }
+
+ test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") {
+ // Create theoretical distributions
+ val stdNormalDist = new NormalDistribution(0.0, 1.0)
+ val expDist = new ExponentialDistribution(0.6)
+ val uniformDist = new UniformRealDistribution(0.0, 1.0)
+ val expDist2 = new ExponentialDistribution(0.2)
+ val stdNormByName = Tuple2("norm", Array(0.0, 1.0))
+
+ apacheCommonMath3EquivalenceTest(stdNormalDist, null, stdNormByName, false)
+ apacheCommonMath3EquivalenceTest(expDist, null, stdNormByName, true)
+ apacheCommonMath3EquivalenceTest(uniformDist, null, stdNormByName, true)
+ apacheCommonMath3EquivalenceTest(expDist, expDist2, null, true)
+ }
+
+ test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") {
+ /*
+ Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample
+ > sessionInfo()
+ R version 3.2.0 (2015-04-16)
+ Platform: x86_64-apple-darwin13.4.0 (64-bit)
+ > set.seed(20)
+ > v <- rnorm(20)
+ > v
+ [1] 1.16268529 -0.58592447 1.78546500 -1.33259371 -0.44656677 0.56960612
+ [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222
+ [13] -0.62812676 1.32322085 -1.52135057 -0.43742787 0.97057758 0.02822264
+ [19] -0.08578219 0.38921440
+ > ks.test(v, pnorm, alternative = "two.sided")
+
+ One-sample Kolmogorov-Smirnov test
+
+ data: v
+ D = 0.18874, p-value = 0.4223
+ alternative hypothesis: two-sided
+ */
+
+ val rKSStat = 0.18874
+ val rKSPVal = 0.4223
+ val rData = sc.parallelize(
+ Array(
+ 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501,
+ -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555,
+ -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063,
+ -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691,
+ 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942
+ )
+ ).toDF("sample")
+ val Row(pValue: Double, statistic: Double) = KolmogorovSmirnovTest
+ .test(rData, "sample", "norm", 0, 1).head()
+ assert(statistic ~== rKSStat relTol 1e-4)
+ assert(pValue ~== rKSPVal relTol 1e-4)
+ }
+}
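A minimal usage sketch of the DataFrame-based API this new suite exercises, assuming an active SparkSession `spark` (the uniform sample is hypothetical and should be rejected against the standard normal):

    import org.apache.spark.ml.stat.KolmogorovSmirnovTest
    import org.apache.spark.sql.Row

    val sample = spark.createDataFrame((1 to 1000).map(i => Tuple1(i / 1000.0))).toDF("sample")

    // Test the column against a named distribution ("norm" with mean 0, sd 1); the result row
    // carries (pValue, statistic) in that order, as destructured in the assertions above.
    val Row(pValue: Double, statistic: Double) =
      KolmogorovSmirnovTest.test(sample, "sample", "norm", 0.0, 1.0).head()
    assert(pValue < 0.05)   // a uniform(0, 1] sample is clearly not standard normal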
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index dbe2ea931fb9c..4dbbd75d2466d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.tree.impl
+import scala.annotation.tailrec
import scala.collection.mutable
import org.apache.spark.SparkFunSuite
@@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
import RandomForestSuite.mapToVec
+ private val seed = 42
+
/////////////////////////////////////////////////////////////////////////////
// Tests for split calculation
/////////////////////////////////////////////////////////////////////////////
@@ -90,12 +93,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
test("find splits for a continuous feature") {
// find splits for normal case
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0,
Map(), Set(),
Array(6), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
- val featureSamples = Array.fill(200000)(math.random)
+ val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 5)
assert(fakeMetadata.numSplits(0) === 5)
@@ -106,7 +109,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// SPARK-16957: Use midpoints for split values.
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
@@ -114,7 +117,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// possibleSplits <= numSplits
{
- val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble)
+ val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((0.0 + 1.0) / 2)
assert(splits === expectedSplits)
@@ -122,7 +125,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// possibleSplits > numSplits
{
- val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble)
+ val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
assert(splits === expectedSplits)
@@ -132,7 +135,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits should not return identical splits
// when there are not enough split candidates, reduce the number of splits in metadata
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0,
Map(), Set(),
Array(5), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
@@ -147,7 +150,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits when most samples close to the minimum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
@@ -161,12 +164,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits when most samples close to the maximum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0,
Map(), Set(),
Array(2), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
- val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+ val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
+ .map(_.toDouble).filter(_ != 0.0)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((1.0 + 2.0) / 2)
assert(splits === expectedSplits)
@@ -174,12 +178,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits for constant feature
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
)
- val featureSamples = Array(0, 0, 0).map(_.toDouble)
+ val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0)
val featureSamplesEmpty = Array.empty[Double]
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array.empty[Double])
@@ -320,10 +324,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.isLeaf === false)
assert(topNode.stats === null)
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
+ val nodesForGroup = Map(0 -> Array(topNode))
+ val treeToNodeToIndexInfo = Map(0 -> Map(
+ topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+ ))
val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
@@ -336,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.stats.impurity > 0.0)
// set impurity and predict for child nodes
- assert(topNode.leftChild.get.toNode.prediction === 0.0)
- assert(topNode.rightChild.get.toNode.prediction === 1.0)
+ assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0)
+ assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0)
assert(topNode.leftChild.get.stats.impurity === 0.0)
assert(topNode.rightChild.get.stats.impurity === 0.0)
}
@@ -362,10 +366,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.isLeaf === false)
assert(topNode.stats === null)
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
+ val nodesForGroup = Map(0 -> Array(topNode))
+ val treeToNodeToIndexInfo = Map(0 -> Map(
+ topNode.id -> new RandomForest.NodeIndexInfo(0, None)
+ ))
val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode),
nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
@@ -378,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.stats.impurity > 0.0)
// set impurity and predict for child nodes
- assert(topNode.leftChild.get.toNode.prediction === 0.0)
- assert(topNode.rightChild.get.toNode.prediction === 1.0)
+ assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0)
+ assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0)
assert(topNode.leftChild.get.stats.impurity === 0.0)
assert(topNode.rightChild.get.stats.impurity === 0.0)
}
@@ -407,7 +411,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
- seed = 42, instr = None).head
+ seed = 42, instr = None, prune = false).head
+
model.rootNode match {
case n: InternalNode => n.split match {
case s: CategoricalSplit =>
@@ -577,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
left right
*/
val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
- val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
+ val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp)
val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
- val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
+ val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp)
- val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
+ val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true)
val parentImp = parent.impurityStats
val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
- val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
+ val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp)
- val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
+ val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true)
val grandImp = grandParent.impurityStats
// Test feature importance computed at different subtrees.
@@ -613,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// Forest consisting of (full tree) + (internal node with 2 leafs)
val trees = Array(parent, grandParent).map { root =>
- new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
- .asInstanceOf[DecisionTreeModel]
+ new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode],
+ numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel]
}
val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
val tree2norm = feature0importance + feature1importance
@@ -631,13 +636,89 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // Tests for pruning of redundant subtrees (generated by a split improving the
+ // impurity measure, but always leading to the same prediction).
+ ///////////////////////////////////////////////////////////////////////////////
+
+ test("SPARK-3159 tree model redundancy - classification") {
+ // The following dataset is set up such that splitting over feature_1 for points having
+ // feature_0 = 0 improves the impurity measure, even though the prediction will always be 0
+ // in both branches.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+ )
+ val rdd = sc.parallelize(arr)
+
+ val numClasses = 2
+ val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4,
+ numClasses = numClasses, maxBins = 32)
+
+ val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+ seed = 42, instr = None).head
+
+ val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+ seed = 42, instr = None, prune = false).head
+
+ assert(prunedTree.numNodes === 5)
+ assert(unprunedTree.numNodes === 7)
+
+ assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+ }
+
+ test("SPARK-3159 tree model redundancy - regression") {
+ // The following dataset is set up such that splitting over feature_0 for points having
+ // feature_1 = 1 improves the impurity measure, even though the prediction will always be 0.5
+ // in both branches.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
+ )
+ val rdd = sc.parallelize(arr)
+
+ val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4,
+ numClasses = 0, maxBins = 32)
+
+ val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+ seed = 42, instr = None).head
+
+ val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto",
+ seed = 42, instr = None, prune = false).head
+
+ assert(prunedTree.numNodes === 3)
+ assert(unprunedTree.numNodes === 5)
+ assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
+ }
}
private object RandomForestSuite {
-
def mapToVec(map: Map[Int, Double]): Vector = {
val size = (map.keys.toSeq :+ 0).max + 1
val (indices, values) = map.toSeq.sortBy(_._1).unzip
Vectors.sparse(size, indices.toArray, values.toArray)
}
+
+ @tailrec
+ private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = {
+ if (nodes.isEmpty) {
+ acc
+ }
+ else {
+ nodes.head match {
+ case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
+ case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count)
+ }
+ }
+ }
}
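
As an aside on the SPARK-3159 tests above: the pruning they exercise amounts to collapsing an internal node whose children end up as leaves with the same prediction. A minimal sketch of that idea, using hypothetical case classes rather than Spark's own LearningNode/InternalNode structures:

    // Hypothetical node types, for illustration only.
    sealed trait SketchNode { def prediction: Double }
    case class SketchLeaf(prediction: Double, count: Long) extends SketchNode
    case class SketchInternal(prediction: Double, left: SketchNode, right: SketchNode)
      extends SketchNode

    def prune(node: SketchNode): SketchNode = node match {
      case SketchInternal(pred, l, r) =>
        (prune(l), prune(r)) match {
          // Both children are leaves with the same prediction: the split is redundant,
          // so replace the subtree with a single leaf and merge the instance counts.
          case (SketchLeaf(p1, c1), SketchLeaf(p2, c2)) if p1 == p2 => SketchLeaf(p1, c1 + c2)
          case (pl, pr) => SketchInternal(pred, pl, pr)
        }
      case leaf => leaf
    }

Applied to the unpruned classification tree from the test, a pass like this collapses the redundant split, which is consistent with the 7-node versus 5-node assertions above.
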
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index b6894b30b0c2b..3f03d909d4a4c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite {
* @param split Split for parent node
* @return Parent node with children attached
*/
- def buildParentNode(left: Node, right: Node, split: Split): Node = {
+ def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = {
val leftImp = left.impurityStats
val rightImp = right.impurityStats
val parentImp = leftImp.copy.add(rightImp)
@@ -168,7 +168,15 @@ private[ml] object TreeTests extends SparkFunSuite {
val gain = parentImp.calculate() -
(leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
val pred = parentImp.predict
- new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
+ if (isClassification) {
+ new ClassificationInternalNode(pred, parentImp.calculate(), gain,
+ left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode],
+ split, parentImp)
+ } else {
+ new RegressionInternalNode(pred, parentImp.calculate(), gain,
+ left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode],
+ split, parentImp)
+ }
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 15dade2627090..e6ee7220d2279 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -25,17 +25,17 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator}
import org.apache.spark.ml.feature.HashingTF
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
-import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
class CrossValidatorSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+ extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -66,6 +66,13 @@ class CrossValidatorSuite
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(cvModel.avgMetrics.length === lrParamMaps.length)
+
+ val result = cvModel.transform(dataset).select("prediction").as[Double].collect()
+ testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") {
+ rows =>
+ val result2 = rows.map(_.getDouble(0))
+ assert(result === result2)
+ }
}
test("cross validation with linear regression") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 9024342d9c831..cd76acf9c67bc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -24,17 +24,17 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest}
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
-import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
class TrainValidationSplitSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+ extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -64,6 +64,13 @@ class TrainValidationSplitSuite
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(tvsModel.validationMetrics.length === lrParamMaps.length)
+
+ val result = tvsModel.transform(dataset).select("prediction").as[Double].collect()
+ testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), tvsModel, "prediction") {
+ rows =>
+ val result2 = rows.map(_.getDouble(0))
+ assert(result === result2)
+ }
}
test("train validation with linear regression") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 4da95e74434ee..4d9e664850c12 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -19,9 +19,10 @@ package org.apache.spark.ml.util
import java.io.{File, IOException}
+import org.json4s.JNothing
import org.scalatest.Suite
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -129,6 +130,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
class MyParams(override val uid: String) extends Params with MLWritable {
final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc")
+ final val shouldNotSetIfSetintParamWithDefault: IntParam =
+ new IntParam(this, "shouldNotSetIfSetintParamWithDefault", "doc")
final val intParam: IntParam = new IntParam(this, "intParam", "doc")
final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc")
final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc")
@@ -150,6 +153,13 @@ class MyParams(override val uid: String) extends Params with MLWritable {
set(doubleArrayParam -> Array(8.0, 9.0))
set(stringArrayParam -> Array("10", "11"))
+ def checkExclusiveParams(): Unit = {
+ if (isSet(shouldNotSetIfSetintParamWithDefault) && isSet(intParamWithDefault)) {
+ throw new SparkException("intParamWithDefault and shouldNotSetIfSetintParamWithDefault " +
+ "shouldn't be set at the same time")
+ }
+ }
+
override def copy(extra: ParamMap): Params = defaultCopy(extra)
override def write: MLWriter = new DefaultParamsWriter(this)
@@ -169,4 +179,65 @@ class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext
val myParams = new MyParams("my_params")
testDefaultReadWrite(myParams)
}
+
+ test("default param shouldn't become user-supplied param after persistence") {
+ val myParams = new MyParams("my_params")
+ myParams.set(myParams.shouldNotSetIfSetintParamWithDefault, 1)
+ myParams.checkExclusiveParams()
+ val loadedMyParams = testDefaultReadWrite(myParams)
+ loadedMyParams.checkExclusiveParams()
+ assert(loadedMyParams.getDefault(loadedMyParams.intParamWithDefault) ==
+ myParams.getDefault(myParams.intParamWithDefault))
+
+ loadedMyParams.set(myParams.intParamWithDefault, 1)
+ intercept[SparkException] {
+ loadedMyParams.checkExclusiveParams()
+ }
+ }
+
+ test("User-supplied value for default param should be kept after persistence") {
+ val myParams = new MyParams("my_params")
+ myParams.set(myParams.intParamWithDefault, 100)
+ val loadedMyParams = testDefaultReadWrite(myParams)
+ assert(loadedMyParams.get(myParams.intParamWithDefault).get == 100)
+ }
+
+ test("Read metadata without default field prior to 2.4") {
+ // Default params are saved in the `paramMap` field of the metadata file prior to Spark 2.4.
+ val metadata = """{"class":"org.apache.spark.ml.util.MyParams",
+ |"timestamp":1518852502761,"sparkVersion":"2.3.0",
+ |"uid":"my_params",
+ |"paramMap":{"intParamWithDefault":0}}""".stripMargin
+ val parsedMetadata = DefaultParamsReader.parseMetadata(metadata)
+ val myParams = new MyParams("my_params")
+ assert(!myParams.isSet(myParams.intParamWithDefault))
+ parsedMetadata.getAndSetParams(myParams)
+
+ // Prior to Spark 2.4, default params are set in the loaded ML instance.
+ assert(myParams.isSet(myParams.intParamWithDefault))
+ }
+
+ test("Should raise error when read metadata without default field after Spark 2.4") {
+ val myParams = new MyParams("my_params")
+
+ val metadata1 = """{"class":"org.apache.spark.ml.util.MyParams",
+ |"timestamp":1518852502761,"sparkVersion":"2.4.0",
+ |"uid":"my_params",
+ |"paramMap":{"intParamWithDefault":0}}""".stripMargin
+ val parsedMetadata1 = DefaultParamsReader.parseMetadata(metadata1)
+ val err1 = intercept[IllegalArgumentException] {
+ parsedMetadata1.getAndSetParams(myParams)
+ }
+ assert(err1.getMessage().contains("Cannot recognize JSON metadata"))
+
+ val metadata2 = """{"class":"org.apache.spark.ml.util.MyParams",
+ |"timestamp":1518852502761,"sparkVersion":"3.0.0",
+ |"uid":"my_params",
+ |"paramMap":{"intParamWithDefault":0}}""".stripMargin
+ val parsedMetadata2 = DefaultParamsReader.parseMetadata(metadata2)
+ val err2 = intercept[IllegalArgumentException] {
+ parsedMetadata2.getAndSetParams(myParams)
+ }
+ assert(err2.getMessage().contains("Cannot recognize JSON metadata"))
+ }
}
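
For reference, SPARK-23455 moves default params out of paramMap into a separate metadata field; a hedged illustration of what post-2.4 metadata could look like (the exact field layout is assumed here, not copied from the writer's output), written in the same style as the test strings above:

    // Assumed shape of metadata once defaults are stored separately.
    val metadataWithDefaults = """{"class":"org.apache.spark.ml.util.MyParams",
      |"timestamp":1518852502761,"sparkVersion":"2.4.0",
      |"uid":"my_params",
      |"paramMap":{},
      |"defaultParamMap":{"intParamWithDefault":0}}""".stripMargin

A reader built against such a layout can restore the defaults as defaults rather than as user-supplied values, which is what keeps them out of the set params after persistence.
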
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 17678aa611a48..76d41f9b23715 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -22,9 +22,11 @@ import java.io.File
import org.scalatest.Suite
import org.apache.spark.SparkContext
-import org.apache.spark.ml.{PipelineModel, Transformer}
-import org.apache.spark.sql.{DataFrame, Encoder, Row}
+import org.apache.spark.ml.{PredictionModel, Transformer}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.functions.col
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.test.TestSparkSession
import org.apache.spark.util.Utils
@@ -62,8 +64,10 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
val columnNames = dataframe.schema.fieldNames
val stream = MemoryStream[A]
- val streamDF = stream.toDS().toDF(columnNames: _*)
-
+ val columnsWithMetadata = dataframe.schema.map { structField =>
+ col(structField.name).as(structField.name, structField.metadata)
+ }
+ val streamDF = stream.toDS().toDF(columnNames: _*).select(columnsWithMetadata: _*)
val data = dataframe.as[A].collect()
val streamOutput = transformer.transform(streamDF)
@@ -108,5 +112,39 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
otherResultCols: _*)(globalCheckFunction)
testTransformerOnDF(dataframe, transformer, firstResultCol,
otherResultCols: _*)(globalCheckFunction)
+ }
+
+ def testTransformerByInterceptingException[A : Encoder](
+ dataframe: DataFrame,
+ transformer: Transformer,
+ expectedMessagePart : String,
+ firstResultCol: String) {
+
+ def hasExpectedMessage(exception: Throwable): Boolean =
+ exception.getMessage.contains(expectedMessagePart) ||
+ (exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart))
+
+ withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") {
+ val exceptionOnDf = intercept[Throwable] {
+ testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit)
+ }
+ assert(hasExpectedMessage(exceptionOnDf))
+ }
+ withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") {
+ val exceptionOnStreamData = intercept[Throwable] {
+ testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit)
+ }
+ assert(hasExpectedMessage(exceptionOnStreamData))
+ }
+ }
+
+ def testPredictionModelSinglePrediction(model: PredictionModel[Vector, _],
+ dataset: Dataset[_]): Unit = {
+
+ model.transform(dataset).select(model.getFeaturesCol, model.getPredictionCol)
+ .collect().foreach {
+ case Row(features: Vector, prediction: Double) =>
+ assert(prediction === model.predict(features))
+ }
}
}
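
A rough usage sketch for the helpers added to MLTest above (suite name, data, and estimator are illustrative and not part of this change); testPredictionModelSinglePrediction checks that per-row model.predict agrees with the prediction column produced by transform:

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.LinearRegression

    class ExampleSuite extends MLTest {
      import testImplicits._

      test("single prediction is consistent with transform") {
        val df = Seq(
          (1.0, Vectors.dense(1.0, 2.0)),
          (2.0, Vectors.dense(2.0, 1.0)),
          (3.0, Vectors.dense(3.0, 3.0)),
          (4.0, Vectors.dense(4.0, 2.5))
        ).toDF("label", "features")
        val model = new LinearRegression().fit(df)
        // Compares model.predict(features) with the "prediction" column, row by row.
        testPredictionModelSinglePrediction(model, df)
      }
    }

testTransformerByInterceptingException follows the same pattern: it runs the transformer on both a DataFrame and a memory stream and asserts that the expected message appears in the thrown exception or its cause.
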
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index aef81c8c173a0..5e72b4d864c1d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml._
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
-import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
@@ -91,30 +91,6 @@ object MLTestingUtils extends SparkFunSuite {
}
}
- def checkNumericTypesALS(
- estimator: ALS,
- spark: SparkSession,
- column: String,
- baseType: NumericType)
- (check: (ALSModel, ALSModel) => Unit)
- (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = {
- val dfs = genRatingsDFWithNumericCols(spark, column)
- val expected = estimator.fit(dfs(baseType))
- val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))
- actuals.foreach { case (_, actual) => check(expected, actual) }
- actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) }
-
- val baseDF = dfs(baseType)
- val others = baseDF.columns.toSeq.diff(Seq(column)).map(col)
- val cols = Seq(col(column).cast(StringType)) ++ others
- val strDF = baseDF.select(cols: _*)
- val thrown = intercept[IllegalArgumentException] {
- estimator.fit(strDF)
- }
- assert(thrown.getMessage.contains(
- s"$column must be of type NumericType but was actually of type StringType"))
- }
-
def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
val expected = evaluator.evaluate(dfs(DoubleType))
@@ -176,26 +152,6 @@ object MLTestingUtils extends SparkFunSuite {
}.toMap
}
- def genRatingsDFWithNumericCols(
- spark: SparkSession,
- column: String): Map[NumericType, DataFrame] = {
- val df = spark.createDataFrame(Seq(
- (0, 10, 1.0),
- (1, 20, 2.0),
- (2, 30, 3.0),
- (3, 40, 4.0),
- (4, 50, 5.0)
- )).toDF("user", "item", "rating")
-
- val others = df.columns.toSeq.diff(Seq(column)).map(col)
- val types: Seq[NumericType] =
- Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
- types.map { t =>
- val cols = Seq(col(column).cast(t)) ++ others
- t -> df.select(cols: _*)
- }.toMap
- }
-
def genEvaluatorDFWithNumericLabelCol(
spark: SparkSession,
labelColName: String = "label",
@@ -291,4 +247,25 @@ object MLTestingUtils extends SparkFunSuite {
}
models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)}
}
+
+ /**
+ * Helper function for testing different input types for the "features" column. Given a DataFrame,
+ * generates three output DataFrames: one with a vector "features" column in float precision,
+ * one with a double array "features" column in float precision, and one with a float array
+ * "features" column.
+ */
+ def generateArrayFeatureDataset(dataset: Dataset[_],
+ featuresColName: String = "features"): (Dataset[_], Dataset[_], Dataset[_]) = {
+ val toFloatVectorUDF = udf { (features: Vector) =>
+ Vectors.dense(features.toArray.map(_.toFloat.toDouble))}
+ val toDoubleArrayUDF = udf { (features: Vector) => features.toArray}
+ val toFloatArrayUDF = udf { (features: Vector) => features.toArray.map(_.toFloat)}
+ val newDataset = dataset.withColumn(featuresColName, toFloatVectorUDF(col(featuresColName)))
+ val newDatasetD = newDataset.withColumn(featuresColName, toDoubleArrayUDF(col(featuresColName)))
+ val newDatasetF = newDataset.withColumn(featuresColName, toFloatArrayUDF(col(featuresColName)))
+ assert(newDataset.schema(featuresColName).dataType.equals(new VectorUDT))
+ assert(newDatasetD.schema(featuresColName).dataType.equals(new ArrayType(DoubleType, false)))
+ assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false)))
+ (newDataset, newDatasetD, newDatasetF)
+ }
}
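
One plausible way a suite could drive the new generateArrayFeatureDataset helper (the wrapper below is illustrative and assumes the estimator under test accepts array-typed "features" columns):

    import org.apache.spark.ml.{Estimator, Model}
    import org.apache.spark.sql.Dataset

    // Illustrative wrapper: fit the same estimator on the vector, double-array and
    // float-array variants of the "features" column and compare the models pairwise.
    def checkArrayFeatureTypes[M <: Model[M]](
        estimator: Estimator[M],
        dataset: Dataset[_])(modelEquals: (M, M) => Unit): Unit = {
      val (vecDF, doubleArrayDF, floatArrayDF) =
        MLTestingUtils.generateArrayFeatureDataset(dataset)
      val models = Seq(vecDF, doubleArrayDF, floatArrayDF).map(estimator.fit(_))
      models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2) }
    }
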
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala
new file mode 100644
index 0000000000000..d2c4832b12bac
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.{File, IOException}
+
+import org.dmg.pmml.PMML
+import org.scalatest.Suite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+trait PMMLReadWriteTest extends TempDirectory { self: Suite =>
+ /**
+ * Test PMML export. Requires that the exported model is small enough to be loaded locally.
+ * Checks that the model can be exported and the result is valid PMML, but does not check
+ * the specific contents of the model.
+ */
+ def testPMMLWrite[T <: Params with GeneralMLWritable](sc: SparkContext, instance: T,
+ checkModelData: PMML => Unit): Unit = {
+ val uid = instance.uid
+ val subdirName = Identifiable.randomUID("pmml-")
+
+ val subdir = new File(tempDir, subdirName)
+ val path = new File(subdir, uid).getPath
+
+ instance.write.format("pmml").save(path)
+ intercept[IOException] {
+ instance.write.format("pmml").save(path)
+ }
+ instance.write.format("pmml").overwrite().save(path)
+ val pmmlStr = sc.textFile(path).collect.mkString("\n")
+ val pmmlModel = PMMLUtils.loadFromString(pmmlStr)
+ assert(pmmlModel.getHeader().getApplication().getName().startsWith("Apache Spark"))
+ checkModelData(pmmlModel)
+ }
+}
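
A hedged sketch of how a suite might invoke testPMMLWrite (the model choice and the PMML assertion are illustrative; whether a given model supports the "pmml" format depends on it implementing GeneralMLWritable):

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.LinearRegression
    import org.dmg.pmml.PMML

    class ExamplePMMLSuite extends MLTest with PMMLReadWriteTest {
      import testImplicits._

      test("linear regression exports readable PMML") {
        val df = Seq(
          (1.0, Vectors.dense(1.0)),
          (2.0, Vectors.dense(2.0)),
          (3.0, Vectors.dense(3.0))
        ).toDF("label", "features")
        val model = new LinearRegression().fit(df)
        testPMMLWrite(sc, model, (pmml: PMML) => {
          // Coarse structural check only; a real test would inspect the regression table.
          assert(pmml.getModels.size() === 1)
        })
      }
    }
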
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala
new file mode 100644
index 0000000000000..dbdc69f95d841
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.util
+
+import java.io.StringReader
+import javax.xml.bind.Unmarshaller
+import javax.xml.transform.Source
+
+import org.dmg.pmml._
+import org.jpmml.model.{ImportFilter, JAXBUtil}
+import org.xml.sax.InputSource
+
+/**
+ * Testing utils for working with PMML.
+ * Predictive Model Markup Language (PMML) is an XML-based file format
+ * developed by the Data Mining Group (www.dmg.org).
+ */
+private[spark] object PMMLUtils {
+ /**
+ * :: Experimental ::
+ * Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported
+ * through external spark-packages.
+ */
+ def loadFromString(input: String): PMML = {
+ val is = new StringReader(input)
+ val transformed = ImportFilter.apply(new InputSource(is))
+ JAXBUtil.unmarshalPMML(transformed)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala
new file mode 100644
index 0000000000000..f4c1f0bdb32cd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.PipelineStage
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.sql.{DataFrame, SparkSession}
+
+class FakeLinearRegressionWriter extends MLWriterFormat {
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ throw new Exception(s"Fake writer doesn't writestart")
+ }
+}
+
+class FakeLinearRegressionWriterWithName extends MLFormatRegister {
+ override def format(): String = "fakeWithName"
+ override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ throw new Exception(s"Fake writer doesn't writestart")
+ }
+}
+
+
+class DuplicateLinearRegressionWriter1 extends MLFormatRegister {
+ override def format(): String = "dupe"
+ override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ throw new Exception(s"Duplicate writer shouldn't have been called")
+ }
+}
+
+class DuplicateLinearRegressionWriter2 extends MLFormatRegister {
+ override def format(): String = "dupe"
+ override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel"
+ override def write(path: String, sparkSession: SparkSession,
+ optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
+ throw new Exception(s"Duplicate writer shouldn't have been called")
+ }
+}
+
+class ReadWriteSuite extends MLTest {
+
+ import testImplicits._
+
+ private val seed: Int = 42
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ dataset = sc.parallelize(LinearDataGenerator.generateLinearInput(
+ intercept = 0.0, weights = Array(1.0, 2.0), xMean = Array(0.0, 1.0),
+ xVariance = Array(2.0, 1.0), nPoints = 10, seed, eps = 0.2)).map(_.asML).toDF()
+ }
+
+ test("unsupported/non existent export formats") {
+ val lr = new LinearRegression()
+ val model = lr.fit(dataset)
+ // Does not exist with a long class name
+ val thrownDNE = intercept[SparkException] {
+ model.write.format("com.holdenkarau.boop").save("boop")
+ }
+ assert(thrownDNE.getMessage().
+ contains("Could not load requested format"))
+
+ // Does not exist with a short name
+ val thrownDNEShort = intercept[SparkException] {
+ model.write.format("boop").save("boop")
+ }
+ assert(thrownDNEShort.getMessage().
+ contains("Could not load requested format"))
+
+ // Check with a valid class that is not a writer format.
+ val thrownInvalid = intercept[SparkException] {
+ model.write.format("org.apache.spark.SparkContext").save("boop2")
+ }
+ assert(thrownInvalid.getMessage()
+ .contains("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat"))
+ }
+
+ test("invalid paths fail") {
+ val lr = new LinearRegression()
+ val model = lr.fit(dataset)
+ val thrown = intercept[Exception] {
+ model.write.format("pmml").save("")
+ }
+ assert(thrown.getMessage().contains("Can not create a Path from an empty string"))
+ }
+
+ test("dummy export format is called") {
+ val lr = new LinearRegression()
+ val model = lr.fit(dataset)
+ val thrown = intercept[Exception] {
+ model.write.format("org.apache.spark.ml.util.FakeLinearRegressionWriter").save("name")
+ }
+ assert(thrown.getMessage().contains("Fake writer doesn't write"))
+ val thrownWithName = intercept[Exception] {
+ model.write.format("fakeWithName").save("name")
+ }
+ assert(thrownWithName.getMessage().contains("Fake writer doesn't write"))
+ }
+
+ test("duplicate format raises error") {
+ val lr = new LinearRegression()
+ val model = lr.fit(dataset)
+ val thrown = intercept[Exception] {
+ model.write.format("dupe").save("dupepanda")
+ }
+ assert(thrown.getMessage().contains("Multiple writers found for"))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 441d0f7614bf6..bc59f3f4125fb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
// if a split does not satisfy min instances per node requirements,
// this split is invalid, even though the information gain of split is large.
val arr = Array(
- LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
- LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
- LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
- LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 0.0)))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
@@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite {
Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](3000)
for (i <- 0 until 3000) {
- if (i < 1000) {
+ if (i < 1001) {
arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
} else if (i < 2000) {
arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
diff --git a/pom.xml b/pom.xml
index 666d5d7169a15..4b4e6c13ea8fd 100644
--- a/pom.xml
+++ b/pom.xml
@@ -129,8 +129,8 @@
1.2.110.12.1.1
- 1.8.2
- 1.4.1
+ 1.10.0
+ 1.4.4nohive1.6.09.3.20.v20170531
@@ -160,7 +160,7 @@
1.9.132.6.72.6.7.1
- 1.1.2.6
+ 1.1.7.11.1.21.2.0-incubating1.10
@@ -185,6 +185,10 @@
2.81.81.0.0
+
0.8.0${java.home}
@@ -575,7 +579,7 @@
commons-netcommons-net
- 2.2
+ 3.1io.netty
@@ -756,6 +760,12 @@
1.10.19test
+
+ org.jmock
+ jmock-junit4
+ test
+ 2.8.4
+ org.scalacheckscalacheck_${scala.binary.version}
@@ -1736,10 +1746,6 @@
org.apache.hivehive-storage-api
-
- io.airlift
- slice
-
@@ -1753,6 +1759,10 @@
org.apache.hadoophadoop-common
+
+ org.apache.hadoop
+ hadoop-mapreduce-client-core
+ org.apache.orcorc-core
@@ -1774,6 +1784,12 @@
parquet-hadoop${parquet.version}${parquet.deps.scope}
+
+
+ commons-pool
+ commons-pool
+
+ org.apache.parquet
@@ -2667,6 +2683,15 @@
+
+ hadoop-3.1
+
+ 3.1.0
+ 2.12.0
+ 3.4.9
+
+
+
yarn
@@ -2686,6 +2711,7 @@
kubernetesresource-managers/kubernetes/core
+ resource-managers/kubernetes/integration-tests
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index d35c50e1d00fe..4f6d5ff898681 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,60 @@ object MimaExcludes {
// Exclude rules for 2.4.x
lazy val v24excludes = v23excludes ++ Seq(
+ // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"),
+
+ // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher.
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printVersionAndExit"),
+
+ // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"),
+
+ // [SPARK-20659] Remove StorageStatus, or make it private
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOnHeapStorageMemory"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOnHeapStorageMemory"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.getExecutorStorageStatus"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numBlocks"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocks"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.containsBlock"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddBlocksById"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"),
+
+ // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="),
+
+ // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"),
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"),
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"),
+
+ // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol")
)
// Exclude rules for 2.3.x
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 7469f11df0294..b606f9355e03b 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -27,6 +27,7 @@ import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
+import com.etsy.sbt.checkstyle.CheckstylePlugin.autoImport._
import com.simplytyped.Antlr4Plugin._
import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
import com.typesafe.tools.mima.plugin.MimaKeys
@@ -56,11 +57,11 @@ object BuildCommons {
val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn,
streamingFlumeSink, streamingFlume,
streamingKafka, sparkGangliaLgpl, streamingKinesisAsl,
- dockerIntegrationTests, hadoopCloud) =
+ dockerIntegrationTests, hadoopCloud, kubernetesIntegrationTests) =
Seq("kubernetes", "mesos", "yarn",
"streaming-flume-sink", "streaming-flume",
"streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl",
- "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _))
+ "docker-integration-tests", "hadoop-cloud", "kubernetes-integration-tests").map(ProjectRef(buildLocation, _))
val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) =
Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly")
@@ -317,7 +318,7 @@ object SparkBuild extends PomBuild {
/* Enable shared settings on all projects */
(allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools))
.foreach(enable(sharedSettings ++ DependencyOverrides.settings ++
- ExcludedDependencies.settings))
+ ExcludedDependencies.settings ++ Checkstyle.settings))
/* Enable tests settings for all projects except examples, assembly and tools */
(allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
@@ -728,7 +729,8 @@ object Unidoc {
scalacOptions in (ScalaUnidoc, unidoc) ++= Seq(
"-groups", // Group similar methods together based on the @group annotation.
- "-skip-packages", "org.apache.hadoop"
+ "-skip-packages", "org.apache.hadoop",
+ "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath
) ++ (
// Add links to sources when generating Scaladoc for a non-snapshot release
if (!isSnapshot.value) {
@@ -740,6 +742,17 @@ object Unidoc {
)
}
+object Checkstyle {
+ lazy val settings = Seq(
+ checkstyleSeverityLevel := Some(CheckstyleSeverityLevel.Error),
+ javaSource in (Compile, checkstyle) := baseDirectory.value / "src/main/java",
+ javaSource in (Test, checkstyle) := baseDirectory.value / "src/test/java",
+ checkstyleConfigLocation := CheckstyleConfigLocation.File("dev/checkstyle.xml"),
+ checkstyleOutputFile := baseDirectory.value / "target/checkstyle-output.xml",
+ checkstyleOutputFile in Test := baseDirectory.value / "target/checkstyle-output.xml"
+ )
+}
+
object CopyDependencies {
val copyDeps = TaskKey[Unit]("copyDeps", "Copies needed dependencies to the build directory.")
diff --git a/project/build.properties b/project/build.properties
index b19518fd7aa1c..d03985d980ec8 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-sbt.version=0.13.16
+sbt.version=0.13.17
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 96bdb9067ae59..ffbd417b0f145 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -1,3 +1,11 @@
+addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1")
+
+// sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's.
+libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.2"
+
+// checkstyle uses guava 23.0.
+libraryDependencies += "com.google.guava" % "guava" % "23.0"
+
// need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5"
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")
diff --git a/python/README.md b/python/README.md
index 3f17fdb98a081..c020d84b01ffd 100644
--- a/python/README.md
+++ b/python/README.md
@@ -22,11 +22,11 @@ This packaging is currently experimental and may change in future versions (alth
Using PySpark requires the Spark JARs, and if you are building this from source please see the builder instructions at
["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html).
-The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to setup your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html).
+The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to set up your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html).
**NOTE:** If you are using this with a Spark standalone cluster you must ensure that the version (including minor version) matches or you may experience odd errors.
## Python Requirements
-At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow).
+At its core PySpark depends on Py4J (currently version 0.10.7), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow).
diff --git a/python/docs/Makefile b/python/docs/Makefile
index 09898f29950ed..b8e079483c90c 100644
--- a/python/docs/Makefile
+++ b/python/docs/Makefile
@@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build
PAPER ?=
BUILDDIR ?= _build
-export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip)
+export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip)
# User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip
deleted file mode 100644
index 2f8edcc0c0b88..0000000000000
Binary files a/python/lib/py4j-0.10.6-src.zip and /dev/null differ
diff --git a/python/lib/py4j-0.10.7-src.zip b/python/lib/py4j-0.10.7-src.zip
new file mode 100644
index 0000000000000..128e321078793
Binary files /dev/null and b/python/lib/py4j-0.10.7-src.zip differ
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 4d142c91629cc..58218918693ca 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -54,6 +54,7 @@
from pyspark.taskcontext import TaskContext
from pyspark.profiler import Profiler, BasicProfiler
from pyspark.version import __version__
+from pyspark._globals import _NoValue
def since(version):
diff --git a/python/pyspark/_globals.py b/python/pyspark/_globals.py
new file mode 100644
index 0000000000000..8e6099db09963
--- /dev/null
+++ b/python/pyspark/_globals.py
@@ -0,0 +1,70 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Module defining global singleton classes.
+
+This module raises a RuntimeError if an attempt to reload it is made. In that
+way the identities of the classes defined here are fixed and will remain so
+even if pyspark itself is reloaded. In particular, a function like the following
+will still work correctly after pyspark is reloaded:
+
+ def foo(arg=pyspark._NoValue):
+ if arg is pyspark._NoValue:
+ ...
+
+See gh-7844 for a discussion of the reload problem that motivated this module.
+
+Note that this approach is taken from NumPy.
+"""
+
+__ALL__ = ['_NoValue']
+
+
+# Disallow reloading this module so as to preserve the identities of the
+# classes defined here.
+if '_is_loaded' in globals():
+ raise RuntimeError('Reloading pyspark._globals is not allowed')
+_is_loaded = True
+
+
+class _NoValueType(object):
+ """Special keyword value.
+
+ The instance of this class may be used as the default value assigned to a
+ deprecated keyword in order to check if it has been given a user defined
+ value.
+
+ This class was copied from NumPy.
+ """
+ __instance = None
+
+ def __new__(cls):
+ # ensure that only one instance exists
+ if not cls.__instance:
+ cls.__instance = super(_NoValueType, cls).__new__(cls)
+ return cls.__instance
+
+ # needed for python 2 to preserve identity through a pickle
+ def __reduce__(self):
+ return (self.__class__, ())
+
+ def __repr__(self):
+ return ""
+
+
+_NoValue = _NoValueType()
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 6ef8cf53cc747..f730d290273fe 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -94,7 +94,6 @@
else:
import socketserver as SocketServer
import threading
-from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, PickleSerializer
@@ -266,4 +265,4 @@ def _start_update_server():
import doctest
(failure_count, test_count) = doctest.testmod()
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 02fc515fb824a..b3dfc99962a35 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -162,4 +162,4 @@ def clear(self):
import doctest
(failure_count, test_count) = doctest.testmod()
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 40e91a2d0655d..88519d7311fcc 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -57,7 +57,6 @@
import types
import weakref
-from pyspark.util import _exception_message
if sys.version < '3':
from pickle import Pickler
@@ -181,6 +180,32 @@ def _builtin_type(name):
return getattr(types, name)
+def _make__new__factory(type_):
+ def _factory():
+ return type_.__new__
+ return _factory
+
+
+# NOTE: These need to be module globals so that they're pickleable as globals.
+_get_dict_new = _make__new__factory(dict)
+_get_frozenset_new = _make__new__factory(frozenset)
+_get_list_new = _make__new__factory(list)
+_get_set_new = _make__new__factory(set)
+_get_tuple_new = _make__new__factory(tuple)
+_get_object_new = _make__new__factory(object)
+
+# Pre-defined set of builtin_function_or_method instances that can be
+# serialized.
+_BUILTIN_TYPE_CONSTRUCTORS = {
+ dict.__new__: _get_dict_new,
+ frozenset.__new__: _get_frozenset_new,
+ set.__new__: _get_set_new,
+ list.__new__: _get_list_new,
+ tuple.__new__: _get_tuple_new,
+ object.__new__: _get_object_new,
+}
+
+
if sys.version_info < (3, 4):
def _walk_global_ops(code):
"""
@@ -237,29 +262,17 @@ def dump(self, obj):
if 'recursion' in e.args[0]:
msg = """Could not pickle object as excessively deep recursion required."""
raise pickle.PicklingError(msg)
- except pickle.PickleError:
- raise
- except Exception as e:
- emsg = _exception_message(e)
- if "'i' format requires" in emsg:
- msg = "Object too large to serialize: %s" % emsg
else:
- msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
- print_exec(sys.stderr)
- raise pickle.PicklingError(msg)
-
+ raise
def save_memoryview(self, obj):
- """Fallback to save_string"""
- Pickler.save_string(self, str(obj))
-
- def save_buffer(self, obj):
- """Fallback to save_string"""
- Pickler.save_string(self,str(obj))
- if PY3:
- dispatch[memoryview] = save_memoryview
- else:
- dispatch[buffer] = save_buffer
+ self.save(obj.tobytes())
+ dispatch[memoryview] = save_memoryview
+
+ if not PY3:
+ def save_buffer(self, obj):
+ self.save(str(obj))
+ dispatch[buffer] = save_buffer # noqa: F821 'buffer' was removed in Python 3
def save_unsupported(self, obj):
raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj))
@@ -318,6 +331,24 @@ def save_function(self, obj, name=None):
Determines what kind of function obj is (e.g. lambda, defined at
interactive prompt, etc) and handles the pickling appropriately.
"""
+ try:
+ should_special_case = obj in _BUILTIN_TYPE_CONSTRUCTORS
+ except TypeError:
+ # Methods of builtin types aren't hashable in python 2.
+ should_special_case = False
+
+ if should_special_case:
+ # We keep a special-cased cache of built-in type constructors at
+ # global scope, because these functions are structured very
+ # differently in different python versions and implementations (for
+ # example, they're instances of types.BuiltinFunctionType in
+ # CPython, but they're ordinary types.FunctionType instances in
+ # PyPy).
+ #
+ # If the function we've received is in that cache, we just
+ # serialize it as a lookup into the cache.
+ return self.save_reduce(_BUILTIN_TYPE_CONSTRUCTORS[obj], (), obj=obj)
+
write = self.write
if name is None:
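A rough illustration of what the constructor cache above buys, assuming the patched module is importable as pyspark.cloudpickle and an interactive session:

import pickle
from pyspark import cloudpickle

# dict.__new__ hits the _BUILTIN_TYPE_CONSTRUCTORS special case and is emitted
# as a lookup through the module-global factory, so plain pickle can load it
# back regardless of how the builtin is represented internally.
payload = cloudpickle.dumps(dict.__new__)
restored = pickle.loads(payload)
assert restored is dict.__new__   # holds on CPython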
@@ -344,7 +375,7 @@ def save_function(self, obj, name=None):
return self.save_global(obj, name)
# a builtin_function_or_method which comes in as an attribute of some
- # object (e.g., object.__new__, itertools.chain.from_iterable) will end
+ # object (e.g., itertools.chain.from_iterable) will end
# up with modname "__main__" and so end up here. But these functions
# have no __code__ attribute in CPython, so the handling for
# user-defined functions below will fail.
@@ -352,16 +383,13 @@ def save_function(self, obj, name=None):
# for different python versions.
if not hasattr(obj, '__code__'):
if PY3:
- if sys.version_info < (3, 4):
- raise pickle.PicklingError("Can't pickle %r" % obj)
- else:
- rv = obj.__reduce_ex__(self.proto)
+ rv = obj.__reduce_ex__(self.proto)
else:
if hasattr(obj, '__self__'):
rv = (getattr, (obj.__self__, name))
else:
raise pickle.PicklingError("Can't pickle %r" % obj)
- return Pickler.save_reduce(self, obj=obj, *rv)
+ return self.save_reduce(obj=obj, *rv)
# if func is lambda, def'ed at prompt, is in main, or is nested, then
# we'll pickle the actual function object rather than simply saving a
@@ -420,20 +448,18 @@ def save_dynamic_class(self, obj):
from global modules.
"""
clsdict = dict(obj.__dict__) # copy dict proxy to a dict
- if not isinstance(clsdict.get('__dict__', None), property):
- # don't extract dict that are properties
- clsdict.pop('__dict__', None)
- clsdict.pop('__weakref__', None)
-
- # hack as __new__ is stored differently in the __dict__
- new_override = clsdict.get('__new__', None)
- if new_override:
- clsdict['__new__'] = obj.__new__
-
- # namedtuple is a special case for Spark where we use the _load_namedtuple function
- if getattr(obj, '_is_namedtuple_', False):
- self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields))
- return
+ clsdict.pop('__weakref__', None)
+
+ # On PyPy, __doc__ is a readonly attribute, so we need to include it in
+ # the initial skeleton class. This is safe because we know that the
+ # doc can't participate in a cycle with the original class.
+ type_kwargs = {'__doc__': clsdict.pop('__doc__', None)}
+
+ # If type overrides __dict__ as a property, include it in the type kwargs.
+ # In Python 2, we can't set this attribute after construction.
+ __dict__ = clsdict.pop('__dict__', None)
+ if isinstance(__dict__, property):
+ type_kwargs['__dict__'] = __dict__
save = self.save
write = self.write
@@ -453,23 +479,12 @@ def save_dynamic_class(self, obj):
# Push the rehydration function.
save(_rehydrate_skeleton_class)
- # Mark the start of the args for the rehydration function.
+ # Mark the start of the args tuple for the rehydration function.
write(pickle.MARK)
- # On PyPy, __doc__ is a readonly attribute, so we need to include it in
- # the initial skeleton class. This is safe because we know that the
- # doc can't participate in a cycle with the original class.
- doc_dict = {'__doc__': clsdict.pop('__doc__', None)}
-
- # Create and memoize an empty class with obj's name and bases.
- save(type(obj))
- save((
- obj.__name__,
- obj.__bases__,
- doc_dict,
- ))
- write(pickle.REDUCE)
- self.memoize(obj)
+ # Create and memoize a skeleton class with obj's name and bases.

+ tp = type(obj)
+ self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj)
# Now save the rest of obj's __dict__. Any references to obj
# encountered while saving will point to the skeleton class.
@@ -522,17 +537,22 @@ def save_function_tuple(self, func):
self.memoize(func)
# save the rest of the func data needed by _fill_function
- save(f_globals)
- save(defaults)
- save(dct)
- save(func.__module__)
- save(closure_values)
+ state = {
+ 'globals': f_globals,
+ 'defaults': defaults,
+ 'dict': dct,
+ 'module': func.__module__,
+ 'closure_values': closure_values,
+ }
+ if hasattr(func, '__qualname__'):
+ state['qualname'] = func.__qualname__
+ save(state)
write(pickle.TUPLE)
write(pickle.REDUCE) # applies _fill_function on the tuple
_extract_code_globals_cache = (
weakref.WeakKeyDictionary()
- if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info")
+ if not hasattr(sys, "pypy_version_info")
else {})
@classmethod
@@ -608,37 +628,22 @@ def save_global(self, obj, name=None, pack=struct.pack):
The name of this method is somewhat misleading: all types get
dispatched here.
"""
- if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
- if obj in _BUILTIN_TYPE_NAMES:
- return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
-
- if name is None:
- name = obj.__name__
-
- modname = getattr(obj, "__module__", None)
- if modname is None:
- try:
- # whichmodule() could fail, see
- # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling
- modname = pickle.whichmodule(obj, name)
- except Exception:
- modname = '__main__'
+ if obj.__module__ == "__main__":
+ return self.save_dynamic_class(obj)
- if modname == '__main__':
- themodule = None
- else:
- __import__(modname)
- themodule = sys.modules[modname]
- self.modules.add(themodule)
+ try:
+ return Pickler.save_global(self, obj, name=name)
+ except Exception:
+ if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
+ if obj in _BUILTIN_TYPE_NAMES:
+ return self.save_reduce(
+ _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
- if hasattr(themodule, name) and getattr(themodule, name) is obj:
- return Pickler.save_global(self, obj, name)
+ typ = type(obj)
+ if typ is not obj and isinstance(obj, (type, types.ClassType)):
+ return self.save_dynamic_class(obj)
- typ = type(obj)
- if typ is not obj and isinstance(obj, (type, types.ClassType)):
- self.save_dynamic_class(obj)
- else:
- raise pickle.PicklingError("Can't pickle %r" % obj)
+ raise
dispatch[type] = save_global
dispatch[types.ClassType] = save_global
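A hedged sketch of the resulting save_global flow: classes whose __module__ is "__main__" are serialized by value through save_dynamic_class, while importable classes still go through the standard Pickler.save_global lookup (interactive __main__ session assumed):

import collections
import pickle
from pyspark import cloudpickle

class Local(object):      # defined in __main__, so pickled by value
    answer = 42

restored = pickle.loads(cloudpickle.dumps(Local))
assert restored.answer == 42

# An importable class round-trips as an ordinary global reference.
assert pickle.loads(cloudpickle.dumps(collections.OrderedDict)) is collections.OrderedDict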
@@ -709,12 +714,7 @@ def save_property(self, obj):
dispatch[property] = save_property
def save_classmethod(self, obj):
- try:
- orig_func = obj.__func__
- except AttributeError: # Python 2.6
- orig_func = obj.__get__(None, object)
- if isinstance(obj, classmethod):
- orig_func = orig_func.__func__ # Unbind
+ orig_func = obj.__func__
self.save_reduce(type(obj), (orig_func,), obj=obj)
dispatch[classmethod] = save_classmethod
dispatch[staticmethod] = save_classmethod
@@ -754,64 +754,6 @@ def __getattribute__(self, item):
if type(operator.attrgetter) is type:
dispatch[operator.attrgetter] = save_attrgetter
- def save_reduce(self, func, args, state=None,
- listitems=None, dictitems=None, obj=None):
- # Assert that args is a tuple or None
- if not isinstance(args, tuple):
- raise pickle.PicklingError("args from reduce() should be a tuple")
-
- # Assert that func is callable
- if not hasattr(func, '__call__'):
- raise pickle.PicklingError("func from reduce should be callable")
-
- save = self.save
- write = self.write
-
- # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ
- if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__":
- cls = args[0]
- if not hasattr(cls, "__new__"):
- raise pickle.PicklingError(
- "args[0] from __newobj__ args has no __new__")
- if obj is not None and cls is not obj.__class__:
- raise pickle.PicklingError(
- "args[0] from __newobj__ args has the wrong class")
- args = args[1:]
- save(cls)
-
- save(args)
- write(pickle.NEWOBJ)
- else:
- save(func)
- save(args)
- write(pickle.REDUCE)
-
- if obj is not None:
- self.memoize(obj)
-
- # More new special cases (that work with older protocols as
- # well): when __reduce__ returns a tuple with 4 or 5 items,
- # the 4th and 5th item should be iterators that provide list
- # items and dict items (as (key, value) tuples), or None.
-
- if listitems is not None:
- self._batch_appends(listitems)
-
- if dictitems is not None:
- self._batch_setitems(dictitems)
-
- if state is not None:
- save(state)
- write(pickle.BUILD)
-
- def save_partial(self, obj):
- """Partial objects do not serialize correctly in python2.x -- this fixes the bugs"""
- self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords))
-
- if sys.version_info < (2,7): # 2.7 supports partial pickling
- dispatch[partial] = save_partial
-
-
def save_file(self, obj):
"""Save a file"""
try:
@@ -859,31 +801,34 @@ def save_ellipsis(self, obj):
def save_not_implemented(self, obj):
self.save_reduce(_gen_not_implemented, ())
- if PY3:
- dispatch[io.TextIOWrapper] = save_file
- else:
+ try: # Python 2
dispatch[file] = save_file
+ except NameError: # Python 3
+ dispatch[io.TextIOWrapper] = save_file
dispatch[type(Ellipsis)] = save_ellipsis
dispatch[type(NotImplemented)] = save_not_implemented
- # WeakSet was added in 2.7.
- if hasattr(weakref, 'WeakSet'):
- def save_weakset(self, obj):
- self.save_reduce(weakref.WeakSet, (list(obj),))
+ def save_weakset(self, obj):
+ self.save_reduce(weakref.WeakSet, (list(obj),))
- dispatch[weakref.WeakSet] = save_weakset
-
- """Special functions for Add-on libraries"""
- def inject_addons(self):
- """Plug in system. Register additional pickling functions if modules already loaded"""
- pass
+ dispatch[weakref.WeakSet] = save_weakset
def save_logger(self, obj):
self.save_reduce(logging.getLogger, (obj.name,), obj=obj)
dispatch[logging.Logger] = save_logger
+ def save_root_logger(self, obj):
+ self.save_reduce(logging.getLogger, (), obj=obj)
+
+ dispatch[logging.RootLogger] = save_root_logger
+
+ """Special functions for Add-on libraries"""
+ def inject_addons(self):
+ """Plug in system. Register additional pickling functions if modules already loaded"""
+ pass
+
# Tornado support
@@ -913,11 +858,12 @@ def dump(obj, file, protocol=2):
def dumps(obj, protocol=2):
file = StringIO()
-
- cp = CloudPickler(file,protocol)
- cp.dump(obj)
-
- return file.getvalue()
+ try:
+ cp = CloudPickler(file,protocol)
+ cp.dump(obj)
+ return file.getvalue()
+ finally:
+ file.close()
# including pickles unloading functions in this namespace
load = pickle.load
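For context, a small round-trip sketch of what dumps() is typically used for in PySpark -- lambdas and closures go through save_function_tuple/_fill_function rather than the stdlib pickler, which cannot serialize them by name:

import pickle
from pyspark import cloudpickle

def make_adder(n):
    return lambda x: x + n          # closure over n, not picklable by stdlib pickle

add3 = pickle.loads(cloudpickle.dumps(make_adder(3)))
assert add3(4) == 7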
@@ -1019,18 +965,40 @@ def __reduce__(cls):
return cls.__name__
-def _fill_function(func, globals, defaults, dict, module, closure_values):
- """ Fills in the rest of function data into the skeleton function object
- that were created via _make_skel_func().
+def _fill_function(*args):
+ """Fills in the rest of function data into the skeleton function object
+
+ The skeleton itself is created by _make_skel_func().
"""
- func.__globals__.update(globals)
- func.__defaults__ = defaults
- func.__dict__ = dict
- func.__module__ = module
+ if len(args) == 2:
+ func = args[0]
+ state = args[1]
+ elif len(args) == 5:
+ # Backwards compat for cloudpickle v0.4.0, after which the `module`
+ # argument was introduced
+ func = args[0]
+ keys = ['globals', 'defaults', 'dict', 'closure_values']
+ state = dict(zip(keys, args[1:]))
+ elif len(args) == 6:
+ # Backwards compat for cloudpickle v0.4.1, after which the function
+ # state was passed as a dict to the _fill_function it-self.
+ func = args[0]
+ keys = ['globals', 'defaults', 'dict', 'module', 'closure_values']
+ state = dict(zip(keys, args[1:]))
+ else:
+ raise ValueError('Unexpected _fill_function arguments: %r' % (args,))
+
+ func.__globals__.update(state['globals'])
+ func.__defaults__ = state['defaults']
+ func.__dict__ = state['dict']
+ if 'module' in state:
+ func.__module__ = state['module']
+ if 'qualname' in state:
+ func.__qualname__ = state['qualname']
cells = func.__closure__
if cells is not None:
- for cell, value in zip(cells, closure_values):
+ for cell, value in zip(cells, state['closure_values']):
if value is not _empty_cell_value:
cell_set(cell, value)
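The len(args) dispatch above exists so that pickles produced by older cloudpickle releases keep loading; a hedged illustration of the three accepted layouts (placeholder values, not real pickled state):

from pyspark.cloudpickle import _fill_function

def func():
    return "ok"

# current layout: (func, state_dict)
_fill_function(func, {'globals': {}, 'defaults': None, 'dict': {},
                      'module': '__main__', 'closure_values': None})
# cloudpickle 0.4.0 layout: no module argument
_fill_function(func, {}, None, {}, None)
# cloudpickle 0.4.1 layout: positional module argument
_fill_function(func, {}, None, {}, '__main__', None)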
@@ -1087,13 +1055,6 @@ def _find_module(mod_name):
file.close()
return path, description
-def _load_namedtuple(name, fields):
- """
- Loads a class generated by namedtuple
- """
- from collections import namedtuple
- return namedtuple(name, fields)
-
"""Constructors for 3rd party libraries
Note: These can never be renamed due to client compatibility issues"""
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index 491b3a81972bc..ab429d9ab10de 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -217,7 +217,7 @@ def _test():
import doctest
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 24905f1c97b21..ede3b6af0a8cf 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -211,9 +211,21 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
for path in self._conf.get("spark.submit.pyFiles", "").split(","):
if path != "":
(dirname, filename) = os.path.split(path)
- if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
- self._python_includes.append(filename)
- sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
+ try:
+ filepath = os.path.join(SparkFiles.getRootDirectory(), filename)
+ if not os.path.exists(filepath):
+ # In case of YARN with shell mode, 'spark.submit.pyFiles' files are
+ # not added via SparkContext.addFile. Here we check if the file exists,
+ # try to copy and then add it to the path. See SPARK-21945.
+ shutil.copyfile(path, filepath)
+ if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
+ self._python_includes.append(filename)
+ sys.path.insert(1, filepath)
+ except Exception:
+ warnings.warn(
+ "Failed to add file [%s] speficied in 'spark.submit.pyFiles' to "
+ "Python path:\n %s" % (path, "\n ".join(sys.path)),
+ RuntimeWarning)
# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
@@ -998,8 +1010,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
- port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
- return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
+ sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
+ return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer))
def show_profiles(self):
""" Print the profile stats to stdout """
@@ -1035,7 +1047,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 7f06d4288c872..ebdd665e349c5 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -29,7 +29,7 @@
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
from pyspark.worker import main as worker_main
-from pyspark.serializers import read_int, write_int
+from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
def compute_real_exit_code(exit_code):
@@ -40,7 +40,7 @@ def compute_real_exit_code(exit_code):
return 1
-def worker(sock):
+def worker(sock, authenticated):
"""
Called by a worker process after the fork().
"""
@@ -56,6 +56,18 @@ def worker(sock):
# otherwise writes also cause a seek that makes us miss data on the read side.
infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536)
+
+ if not authenticated:
+ client_secret = UTF8Deserializer().loads(infile)
+ if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret:
+ write_with_length("ok".encode("utf-8"), outfile)
+ outfile.flush()
+ else:
+ write_with_length("err".encode("utf-8"), outfile)
+ outfile.flush()
+ sock.close()
+ return 1
+
exit_code = 0
try:
worker_main(infile, outfile)
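The block above makes each forked worker validate the shared secret before running any task code. A hedged, Python-flavoured sketch of the peer side of that handshake (in practice it is performed by the JVM worker factory; host and port below are placeholders):

import os
import socket

from pyspark.serializers import write_with_length, UTF8Deserializer

worker_port = 56789                       # placeholder, not a real port
sock = socket.create_connection(("127.0.0.1", worker_port))
outfile = sock.makefile("wb", 65536)
infile = sock.makefile("rb", 65536)

# Send the length-prefixed secret, then expect a length-prefixed "ok".
write_with_length(os.environ["PYTHON_WORKER_FACTORY_SECRET"].encode("utf-8"), outfile)
outfile.flush()
if UTF8Deserializer().loads(infile) != "ok":
    raise Exception("worker rejected the connection secret")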
@@ -89,7 +101,7 @@ def shutdown(code):
signal.signal(SIGTERM, SIG_DFL)
# Send SIGHUP to notify workers of shutdown
os.kill(0, SIGHUP)
- exit(code)
+ sys.exit(code)
def handle_sigterm(*args):
shutdown(1)
@@ -153,8 +165,11 @@ def handle_sigterm(*args):
write_int(os.getpid(), outfile)
outfile.flush()
outfile.close()
+ authenticated = False
while True:
- code = worker(sock)
+ code = worker(sock, authenticated)
+ if code == 0:
+ authenticated = True
if not reuse or code:
# wait for closing
try:
diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py
index 212a618b767ab..9cf0e8c8d2fe9 100755
--- a/python/pyspark/find_spark_home.py
+++ b/python/pyspark/find_spark_home.py
@@ -68,7 +68,7 @@ def is_spark_home(path):
return next(path for path in paths if is_spark_home(path))
except StopIteration:
print("Could not find valid SPARK_HOME while searching {0}".format(paths), file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
print(_find_spark_home())
diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py
index b27e91a4cc251..6af084adcf373 100644
--- a/python/pyspark/heapq3.py
+++ b/python/pyspark/heapq3.py
@@ -884,6 +884,7 @@ def nlargest(n, iterable, key=None):
if __name__ == "__main__":
import doctest
+ import sys
(failure_count, test_count) = doctest.testmod()
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 3e704fe9bf6ec..0afbe9dc6aa3e 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -21,16 +21,19 @@
import select
import signal
import shlex
+import shutil
import socket
import platform
+import tempfile
+import time
from subprocess import Popen, PIPE
if sys.version >= '3':
xrange = range
-from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
from pyspark.find_spark_home import _find_spark_home
-from pyspark.serializers import read_int
+from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
def launch_gateway(conf=None):
@@ -41,6 +44,7 @@ def launch_gateway(conf=None):
"""
if "PYSPARK_GATEWAY_PORT" in os.environ:
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
+ gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
else:
SPARK_HOME = _find_spark_home()
# Launch the Py4j gateway using Spark's run command so that we pick up the
@@ -59,40 +63,40 @@ def launch_gateway(conf=None):
])
command = command + shlex.split(submit_args)
- # Start a socket that will be used by PythonGatewayServer to communicate its port to us
- callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- callback_socket.bind(('127.0.0.1', 0))
- callback_socket.listen(1)
- callback_host, callback_port = callback_socket.getsockname()
- env = dict(os.environ)
- env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
- env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)
-
- # Launch the Java gateway.
- # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
- if not on_windows:
- # Don't send ctrl-c / SIGINT to the Java gateway:
- def preexec_func():
- signal.signal(signal.SIGINT, signal.SIG_IGN)
- proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
- else:
- # preexec_fn not supported on Windows
- proc = Popen(command, stdin=PIPE, env=env)
-
- gateway_port = None
- # We use select() here in order to avoid blocking indefinitely if the subprocess dies
- # before connecting
- while gateway_port is None and proc.poll() is None:
- timeout = 1 # (seconds)
- readable, _, _ = select.select([callback_socket], [], [], timeout)
- if callback_socket in readable:
- gateway_connection = callback_socket.accept()[0]
- # Determine which ephemeral port the server started on:
- gateway_port = read_int(gateway_connection.makefile(mode="rb"))
- gateway_connection.close()
- callback_socket.close()
- if gateway_port is None:
- raise Exception("Java gateway process exited before sending the driver its port number")
+ # Create a temporary directory where the gateway server should write the connection
+ # information.
+ conn_info_dir = tempfile.mkdtemp()
+ try:
+ fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
+ os.close(fd)
+ os.unlink(conn_info_file)
+
+ env = dict(os.environ)
+ env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
+
+ # Launch the Java gateway.
+ # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
+ if not on_windows:
+ # Don't send ctrl-c / SIGINT to the Java gateway:
+ def preexec_func():
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+ proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
+ else:
+ # preexec_fn not supported on Windows
+ proc = Popen(command, stdin=PIPE, env=env)
+
+ # Wait for the file to appear, or for the process to exit, whichever happens first.
+ while not proc.poll() and not os.path.isfile(conn_info_file):
+ time.sleep(0.1)
+
+ if not os.path.isfile(conn_info_file):
+ raise Exception("Java gateway process exited before sending its port number")
+
+ with open(conn_info_file, "rb") as info:
+ gateway_port = read_int(info)
+ gateway_secret = UTF8Deserializer().loads(info)
+ finally:
+ shutil.rmtree(conn_info_dir)
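The gateway now reads an int port followed by a length-prefixed UTF-8 secret from the connection-info file; a hedged sketch of the matching writer side (the real writer is the JVM PythonGatewayServer; path and values here are placeholders):

from pyspark.serializers import write_int, write_with_length

conn_info_file = "/tmp/example-conn-info"          # placeholder path
with open(conn_info_file, "wb") as info:
    write_int(56789, info)                                          # gateway port
    write_with_length("not-a-real-secret".encode("utf-8"), info)    # auth secret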
# In Windows, ensure the Java child processes do not linger after Python has exited.
# In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
@@ -111,7 +115,9 @@ def killChild():
atexit.register(killChild)
# Connect to the gateway
- gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
+ gateway = JavaGateway(
+ gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
+ auto_convert=True))
# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
@@ -126,3 +132,16 @@ def killChild():
java_import(gateway.jvm, "scala.Tuple2")
return gateway
+
+
+def do_server_auth(conn, auth_secret):
+ """
+ Performs the authentication protocol defined by the SocketAuthHelper class on the given
+ file-like object 'conn'.
+ """
+ write_with_length(auth_secret.encode("utf-8"), conn)
+ conn.flush()
+ reply = UTF8Deserializer().loads(conn)
+ if reply != "ok":
+ conn.close()
+ raise Exception("Unexpected reply from iterator server.")
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 129d7d68f7cbb..d99a25390db15 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -21,5 +21,11 @@
"""
from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer
from pyspark.ml.pipeline import Pipeline, PipelineModel
+from pyspark.ml import classification, clustering, evaluation, feature, fpm, \
+ image, pipeline, recommendation, regression, stat, tuning, util, linalg, param
-__all__ = ["Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
+__all__ = [
+ "Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel",
+ "classification", "clustering", "evaluation", "feature", "fpm", "image",
+ "recommendation", "regression", "stat", "tuning", "util", "linalg", "param",
+]
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 27ad1e80aa0d3..1754c48937a62 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -16,6 +16,7 @@
#
import operator
+import sys
from multiprocessing.pool import ThreadPool
from pyspark import since, keyword_only
@@ -1130,6 +1131,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return RandomForestClassificationModel(java_model)
+ @since("2.4.0")
+ def setFeatureSubsetStrategy(self, value):
+ """
+ Sets the value of :py:attr:`featureSubsetStrategy`.
+ """
+ return self._set(featureSubsetStrategy=value)
+
class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
@@ -1192,6 +1200,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42)
+ >>> gbt.getFeatureSubsetStrategy()
+ 'all'
>>> model = gbt.fit(td)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
@@ -1221,6 +1231,12 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
True
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+ >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
+ ... ["indexed", "features"])
+ >>> model.evaluateEachIteration(validation)
+ [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
+ >>> model.numClasses
+ 2
.. versionadded:: 1.4.0
"""
@@ -1239,19 +1255,22 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
- maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0):
+ maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
+ featureSubsetStrategy="all"):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
- lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0)
+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
+ featureSubsetStrategy="all")
"""
super(GBTClassifier, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.GBTClassifier", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0)
+ lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
+ featureSubsetStrategy="all")
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1260,12 +1279,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0):
+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
+ featureSubsetStrategy="all"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
- lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0)
+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
+ featureSubsetStrategy="all")
Sets params for Gradient Boosted Tree Classification.
"""
kwargs = self._input_kwargs
@@ -1288,8 +1309,15 @@ def getLossType(self):
"""
return self.getOrDefault(self.lossType)
+ @since("2.4.0")
+ def setFeatureSubsetStrategy(self, value):
+ """
+ Sets the value of :py:attr:`featureSubsetStrategy`.
+ """
+ return self._set(featureSubsetStrategy=value)
-class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
+
+class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by GBTClassifier.
@@ -1318,6 +1346,17 @@ def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+ @since("2.4.0")
+ def evaluateEachIteration(self, dataset):
+ """
+ Method to compute error or loss for every iteration of gradient boosting.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ return self._call_java("evaluateEachIteration", dataset)
+
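An illustrative (not authoritative) pairing of the two additions -- featureSubsetStrategy on the estimator and evaluateEachIteration on the fitted model; train_df and valid_df are assumed DataFrames with "indexed"/"features" columns:

from pyspark.ml.classification import GBTClassifier

gbt = GBTClassifier(maxIter=5, labelCol="indexed", seed=42)
gbt = gbt.setFeatureSubsetStrategy("sqrt")      # new in 2.4.0
model = gbt.fit(train_df)                       # train_df: assumed DataFrame
losses = model.evaluateEachIteration(valid_df)  # one loss value per boosting iteration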
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
@@ -1542,12 +1581,12 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
solver="l-bfgs", initialWeights=None, probabilityCol="probability",
- rawPredicitionCol="rawPrediction"):
+ rawPredictionCol="rawPrediction"):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
- rawPredicitionCol="rawPrediction")
+ rawPredictionCol="rawPrediction")
"""
super(MultilayerPerceptronClassifier, self).__init__()
self._java_obj = self._new_java_obj(
@@ -1561,12 +1600,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03,
solver="l-bfgs", initialWeights=None, probabilityCol="probability",
- rawPredicitionCol="rawPrediction"):
+ rawPredictionCol="rawPrediction"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \
solver="l-bfgs", initialWeights=None, probabilityCol="probability", \
- rawPredicitionCol="rawPrediction"):
+ rawPredictionCol="rawPrediction"):
Sets params for MultilayerPerceptronClassifier.
"""
kwargs = self._input_kwargs
@@ -2043,4 +2082,4 @@ def _to_java(self):
except OSError:
pass
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 66fb00508522e..4aa1cf84b5824 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -15,16 +15,19 @@
# limitations under the License.
#
+import sys
+
from pyspark import since, keyword_only
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
+from pyspark.sql import DataFrame
__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary',
'KMeans', 'KMeansModel',
'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary',
- 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel']
+ 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel', 'PowerIterationClustering']
class ClusteringSummary(JavaWrapper):
@@ -403,17 +406,23 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
typeConverter=TypeConverters.toString)
initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " +
"initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt)
+ distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " +
+ "Supported options: 'euclidean' and 'cosine'.",
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, featuresCol="features", predictionCol="prediction", k=2,
- initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None):
+ initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None,
+ distanceMeasure="euclidean"):
"""
__init__(self, featuresCol="features", predictionCol="prediction", k=2, \
- initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None)
+ initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \
+ distanceMeasure="euclidean")
"""
super(KMeans, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid)
- self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20)
+ self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20,
+ distanceMeasure="euclidean")
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -423,10 +432,12 @@ def _create_model(self, java_model):
@keyword_only
@since("1.5.0")
def setParams(self, featuresCol="features", predictionCol="prediction", k=2,
- initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None):
+ initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None,
+ distanceMeasure="euclidean"):
"""
setParams(self, featuresCol="features", predictionCol="prediction", k=2, \
- initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None)
+ initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \
+ distanceMeasure="euclidean")
Sets params for KMeans.
"""
@@ -475,6 +486,20 @@ def getInitSteps(self):
"""
return self.getOrDefault(self.initSteps)
+ @since("2.4.0")
+ def setDistanceMeasure(self, value):
+ """
+ Sets the value of :py:attr:`distanceMeasure`.
+ """
+ return self._set(distanceMeasure=value)
+
+ @since("2.4.0")
+ def getDistanceMeasure(self):
+ """
+ Gets the value of `distanceMeasure`
+ """
+ return self.getOrDefault(self.distanceMeasure)
+
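A short sketch of the new parameter in use (features_df is an assumed DataFrame with a "features" vector column):

from pyspark.ml.clustering import KMeans

kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine")
model = kmeans.fit(features_df)
centers = model.clusterCenters()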
class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
@@ -812,7 +837,7 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInter
Terminology:
- - "term" = "word": an el
+ - "term" = "word": an element of the vocabulary
- "token": instance of a term appearing in a document
- "topic": multinomial distribution over terms representing some concept
- "document": one piece of text, corresponding to one row in the input data
@@ -914,7 +939,7 @@ def __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInte
k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
subsamplingRate=0.05, optimizeDocConcentration=True,\
docConcentration=None, topicConcentration=None,\
- topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+ topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
"""
super(LDA, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.LDA", self.uid)
@@ -943,7 +968,7 @@ def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInt
k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\
subsamplingRate=0.05, optimizeDocConcentration=True,\
docConcentration=None, topicConcentration=None,\
- topicDistributionCol="topicDistribution", keepLastCheckpoint=True):
+ topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
Sets params for LDA.
"""
@@ -1132,6 +1157,179 @@ def getKeepLastCheckpoint(self):
return self.getOrDefault(self.keepLastCheckpoint)
+@inherit_doc
+class PowerIterationClustering(HasMaxIter, HasWeightCol, JavaParams, JavaMLReadable,
+ JavaMLWritable):
+ """
+ .. note:: Experimental
+
+ Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by
+ Lin and Cohen. From the abstract:
+ PIC finds a very low-dimensional embedding of a dataset using truncated power
+ iteration on a normalized pair-wise similarity matrix of the data.
+
+ This class is not yet an Estimator/Transformer, use the :py:func:`assignClusters` method
+ to run the PowerIterationClustering algorithm.
+
+ .. seealso:: `Wikipedia on Spectral clustering \
+ <http://en.wikipedia.org/wiki/Spectral_clustering>`_
+
+ >>> data = [(1, 0, 0.5), \
+ (2, 0, 0.5), (2, 1, 0.7), \
+ (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9), \
+ (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1), \
+ (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)]
+ >>> df = spark.createDataFrame(data).toDF("src", "dst", "weight")
+ >>> pic = PowerIterationClustering(k=2, maxIter=40, weightCol="weight")
+ >>> assignments = pic.assignClusters(df)
+ >>> assignments.sort(assignments.id).show(truncate=False)
+ +---+-------+
+ |id |cluster|
+ +---+-------+
+ |0 |1 |
+ |1 |1 |
+ |2 |1 |
+ |3 |1 |
+ |4 |1 |
+ |5 |0 |
+ +---+-------+
+ ...
+ >>> pic_path = temp_path + "/pic"
+ >>> pic.save(pic_path)
+ >>> pic2 = PowerIterationClustering.load(pic_path)
+ >>> pic2.getK()
+ 2
+ >>> pic2.getMaxIter()
+ 40
+
+ .. versionadded:: 2.4.0
+ """
+
+ k = Param(Params._dummy(), "k",
+ "The number of clusters to create. Must be > 1.",
+ typeConverter=TypeConverters.toInt)
+ initMode = Param(Params._dummy(), "initMode",
+ "The initialization algorithm. This can be either " +
+ "'random' to use a random vector as vertex properties, or 'degree' to use " +
+ "a normalized sum of similarities with other vertices. Supported options: " +
+ "'random' and 'degree'.",
+ typeConverter=TypeConverters.toString)
+ srcCol = Param(Params._dummy(), "srcCol",
+ "Name of the input column for source vertex IDs.",
+ typeConverter=TypeConverters.toString)
+ dstCol = Param(Params._dummy(), "dstCol",
+ "Name of the input column for destination vertex IDs.",
+ typeConverter=TypeConverters.toString)
+
+ @keyword_only
+ def __init__(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",
+ weightCol=None):
+ """
+ __init__(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\
+ weightCol=None)
+ """
+ super(PowerIterationClustering, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid)
+ self._setDefault(k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst")
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("2.4.0")
+ def setParams(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",
+ weightCol=None):
+ """
+ setParams(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\
+ weightCol=None)
+ Sets params for PowerIterationClustering.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ @since("2.4.0")
+ def setK(self, value):
+ """
+ Sets the value of :py:attr:`k`.
+ """
+ return self._set(k=value)
+
+ @since("2.4.0")
+ def getK(self):
+ """
+ Gets the value of :py:attr:`k` or its default value.
+ """
+ return self.getOrDefault(self.k)
+
+ @since("2.4.0")
+ def setInitMode(self, value):
+ """
+ Sets the value of :py:attr:`initMode`.
+ """
+ return self._set(initMode=value)
+
+ @since("2.4.0")
+ def getInitMode(self):
+ """
+ Gets the value of :py:attr:`initMode` or its default value.
+ """
+ return self.getOrDefault(self.initMode)
+
+ @since("2.4.0")
+ def setSrcCol(self, value):
+ """
+ Sets the value of :py:attr:`srcCol`.
+ """
+ return self._set(srcCol=value)
+
+ @since("2.4.0")
+ def getSrcCol(self):
+ """
+ Gets the value of :py:attr:`srcCol` or its default value.
+ """
+ return self.getOrDefault(self.srcCol)
+
+ @since("2.4.0")
+ def setDstCol(self, value):
+ """
+ Sets the value of :py:attr:`dstCol`.
+ """
+ return self._set(dstCol=value)
+
+ @since("2.4.0")
+ def getDstCol(self):
+ """
+ Gets the value of :py:attr:`dstCol` or its default value.
+ """
+ return self.getOrDefault(self.dstCol)
+
+ @since("2.4.0")
+ def assignClusters(self, dataset):
+ """
+ Run the PIC algorithm and returns a cluster assignment for each input vertex.
+
+ :param dataset:
+ A dataset with columns src, dst, weight representing the affinity matrix,
+ which is the matrix A in the PIC paper. Suppose the src column value is i,
+ the dst column value is j, the weight column value is similarity s,,ij,,
+ which must be nonnegative. This is a symmetric matrix and hence
+ s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be
+ either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are
+ ignored, because we assume s,,ij,, = 0.0.
+
+ :return:
+ A dataset that contains columns of vertex id and the corresponding cluster for
+ the id. The schema of it will be:
+ - id: Long
+ - cluster: Int
+
+ .. versionadded:: 2.4.0
+ """
+ self._transfer_params_to_java()
+ jdf = self._java_obj.assignClusters(dataset._jdf)
+ return DataFrame(jdf, dataset.sql_ctx)
+
+
if __name__ == "__main__":
import doctest
import pyspark.ml.clustering
@@ -1159,4 +1357,4 @@ def getKeepLastCheckpoint(self):
except OSError:
pass
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 0cbce9b40048f..8eaf07645a37f 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import sys
from abc import abstractmethod, ABCMeta
from pyspark import since, keyword_only
@@ -362,18 +363,21 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol,
metricName = Param(Params._dummy(), "metricName",
"metric name in evaluation (silhouette)",
typeConverter=TypeConverters.toString)
+ distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " +
+ "Supported options: 'squaredEuclidean' and 'cosine'.",
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, predictionCol="prediction", featuresCol="features",
- metricName="silhouette"):
+ metricName="silhouette", distanceMeasure="squaredEuclidean"):
"""
__init__(self, predictionCol="prediction", featuresCol="features", \
- metricName="silhouette")
+ metricName="silhouette", distanceMeasure="squaredEuclidean")
"""
super(ClusteringEvaluator, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid)
- self._setDefault(metricName="silhouette")
+ self._setDefault(metricName="silhouette", distanceMeasure="squaredEuclidean")
kwargs = self._input_kwargs
self._set(**kwargs)
@@ -394,15 +398,30 @@ def getMetricName(self):
@keyword_only
@since("2.3.0")
def setParams(self, predictionCol="prediction", featuresCol="features",
- metricName="silhouette"):
+ metricName="silhouette", distanceMeasure="squaredEuclidean"):
"""
setParams(self, predictionCol="prediction", featuresCol="features", \
- metricName="silhouette")
+ metricName="silhouette", distanceMeasure="squaredEuclidean")
Sets params for clustering evaluator.
"""
kwargs = self._input_kwargs
return self._set(**kwargs)
+ @since("2.4.0")
+ def setDistanceMeasure(self, value):
+ """
+ Sets the value of :py:attr:`distanceMeasure`.
+ """
+ return self._set(distanceMeasure=value)
+
+ @since("2.4.0")
+ def getDistanceMeasure(self):
+ """
+ Gets the value of `distanceMeasure`
+ """
+ return self.getOrDefault(self.distanceMeasure)
+
+
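A short sketch pairing the evaluator with the new option (predictions_df is an assumed DataFrame produced by a fitted clustering model):

from pyspark.ml.evaluation import ClusteringEvaluator

evaluator = ClusteringEvaluator(distanceMeasure="cosine")
silhouette = evaluator.evaluate(predictions_df)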
if __name__ == "__main__":
import doctest
import tempfile
@@ -428,4 +447,4 @@ def setParams(self, predictionCol="prediction", featuresCol="features",
except OSError:
pass
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index da85ba761a145..14800d4d9327a 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -19,12 +19,12 @@
if sys.version > '3':
basestring = str
-from pyspark import since, keyword_only
+from pyspark import since, keyword_only, SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.linalg import _convert_to_vector
from pyspark.ml.param.shared import *
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm
from pyspark.ml.common import inherit_doc
__all__ = ['Binarizer',
@@ -403,8 +403,84 @@ def getSplits(self):
return self.getOrDefault(self.splits)
+class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
+ """
+ Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`.
+ """
+
+ minTF = Param(
+ Params._dummy(), "minTF", "Filter to ignore rare words in" +
+ " a document. For each document, terms with frequency/count less than the given" +
+ " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
+ " times the term must appear in the document); if this is a double in [0,1), then this " +
+ "specifies a fraction (out of the document's token count). Note that the parameter is " +
+ "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
+ typeConverter=TypeConverters.toFloat)
+ minDF = Param(
+ Params._dummy(), "minDF", "Specifies the minimum number of" +
+ " different documents a term must appear in to be included in the vocabulary." +
+ " If this is an integer >= 1, this specifies the number of documents the term must" +
+ " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
+ " Default 1.0", typeConverter=TypeConverters.toFloat)
+ maxDF = Param(
+ Params._dummy(), "maxDF", "Specifies the maximum number of" +
+ " different documents a term could appear in to be included in the vocabulary." +
+ " A term that appears more than the threshold will be ignored. If this is an" +
+ " integer >= 1, this specifies the maximum number of documents the term could appear in;" +
+ " if this is a double in [0,1), then this specifies the maximum" +
+ " fraction of documents the term could appear in." +
+ " Default (2^63) - 1", typeConverter=TypeConverters.toFloat)
+ vocabSize = Param(
+ Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
+ typeConverter=TypeConverters.toInt)
+ binary = Param(
+ Params._dummy(), "binary", "Binary toggle to control the output vector values." +
+ " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
+ " for discrete probabilistic models that model binary events rather than integer counts." +
+ " Default False", typeConverter=TypeConverters.toBoolean)
+
+ def __init__(self, *args):
+ super(_CountVectorizerParams, self).__init__(*args)
+ self._setDefault(minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False)
+
+ @since("1.6.0")
+ def getMinTF(self):
+ """
+ Gets the value of minTF or its default value.
+ """
+ return self.getOrDefault(self.minTF)
+
+ @since("1.6.0")
+ def getMinDF(self):
+ """
+ Gets the value of minDF or its default value.
+ """
+ return self.getOrDefault(self.minDF)
+
+ @since("2.4.0")
+ def getMaxDF(self):
+ """
+ Gets the value of maxDF or its default value.
+ """
+ return self.getOrDefault(self.maxDF)
+
+ @since("1.6.0")
+ def getVocabSize(self):
+ """
+ Gets the value of vocabSize or its default value.
+ """
+ return self.getOrDefault(self.vocabSize)
+
+ @since("2.0.0")
+ def getBinary(self):
+ """
+ Gets the value of binary or its default value.
+ """
+ return self.getOrDefault(self.binary)
+
+
@inherit_doc
-class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
"""
Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`.
@@ -437,54 +513,40 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
>>> loadedModel = CountVectorizerModel.load(modelPath)
>>> loadedModel.vocabulary == model.vocabulary
True
+ >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"],
+ ... inputCol="raw", outputCol="vectors")
+ >>> fromVocabModel.transform(df).show(truncate=False)
+ +-----+---------------+-------------------------+
+ |label|raw |vectors |
+ +-----+---------------+-------------------------+
+ |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])|
+ |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
+ +-----+---------------+-------------------------+
+ ...
.. versionadded:: 1.6.0
"""
- minTF = Param(
- Params._dummy(), "minTF", "Filter to ignore rare words in" +
- " a document. For each document, terms with frequency/count less than the given" +
- " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
- " times the term must appear in the document); if this is a double in [0,1), then this " +
- "specifies a fraction (out of the document's token count). Note that the parameter is " +
- "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
- typeConverter=TypeConverters.toFloat)
- minDF = Param(
- Params._dummy(), "minDF", "Specifies the minimum number of" +
- " different documents a term must appear in to be included in the vocabulary." +
- " If this is an integer >= 1, this specifies the number of documents the term must" +
- " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
- " Default 1.0", typeConverter=TypeConverters.toFloat)
- vocabSize = Param(
- Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
- typeConverter=TypeConverters.toInt)
- binary = Param(
- Params._dummy(), "binary", "Binary toggle to control the output vector values." +
- " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
- " for discrete probabilistic models that model binary events rather than integer counts." +
- " Default False", typeConverter=TypeConverters.toBoolean)
-
@keyword_only
- def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
- outputCol=None):
+ def __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,
+ inputCol=None, outputCol=None):
"""
- __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\
- outputCol=None)
+ __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\
+ inputCol=None,outputCol=None)
"""
super(CountVectorizer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
self.uid)
- self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.6.0")
- def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
- outputCol=None):
+ def setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,
+ inputCol=None, outputCol=None):
"""
- setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\
- outputCol=None)
+ setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\
+ inputCol=None, outputCol=None)
Set the params for the CountVectorizer
"""
kwargs = self._input_kwargs
@@ -497,13 +559,6 @@ def setMinTF(self, value):
"""
return self._set(minTF=value)
- @since("1.6.0")
- def getMinTF(self):
- """
- Gets the value of minTF or its default value.
- """
- return self.getOrDefault(self.minTF)
-
@since("1.6.0")
def setMinDF(self, value):
"""
@@ -511,12 +566,12 @@ def setMinDF(self, value):
"""
return self._set(minDF=value)
- @since("1.6.0")
- def getMinDF(self):
+ @since("2.4.0")
+ def setMaxDF(self, value):
"""
- Gets the value of minDF or its default value.
+ Sets the value of :py:attr:`maxDF`.
"""
- return self.getOrDefault(self.minDF)
+ return self._set(maxDF=value)
@since("1.6.0")
def setVocabSize(self, value):
@@ -525,13 +580,6 @@ def setVocabSize(self, value):
"""
return self._set(vocabSize=value)
- @since("1.6.0")
- def getVocabSize(self):
- """
- Gets the value of vocabSize or its default value.
- """
- return self.getOrDefault(self.vocabSize)
-
@since("2.0.0")
def setBinary(self, value):
"""
@@ -539,24 +587,40 @@ def setBinary(self, value):
"""
return self._set(binary=value)
- @since("2.0.0")
- def getBinary(self):
- """
- Gets the value of binary or its default value.
- """
- return self.getOrDefault(self.binary)
-
def _create_model(self, java_model):
return CountVectorizerModel(java_model)
-class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
+@inherit_doc
+class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`CountVectorizer`.
.. versionadded:: 1.6.0
"""
+ @classmethod
+ @since("2.4.0")
+ def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None):
+ """
+ Construct the model directly from a vocabulary list of strings,
+ requires an active SparkContext.
+ """
+ sc = SparkContext._active_spark_context
+ java_class = sc._gateway.jvm.java.lang.String
+ jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
+ model = CountVectorizerModel._create_from_java_class(
+ "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)
+ model.setInputCol(inputCol)
+ if outputCol is not None:
+ model.setOutputCol(outputCol)
+ if minTF is not None:
+ model.setMinTF(minTF)
+ if binary is not None:
+ model.setBinary(binary)
+ model._set(vocabSize=len(vocabulary))
+ return model
+
@property
@since("1.6.0")
def vocabulary(self):
@@ -565,6 +629,20 @@ def vocabulary(self):
"""
return self._call_java("vocabulary")
+ @since("2.4.0")
+ def setMinTF(self, value):
+ """
+ Sets the value of :py:attr:`minTF`.
+ """
+ return self._set(minTF=value)
+
+ @since("2.4.0")
+ def setBinary(self, value):
+ """
+ Sets the value of :py:attr:`binary`.
+ """
+ return self._set(binary=value)
+
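A short sketch of the new maxDF cap (docs_df is an assumed DataFrame with a tokenised "raw" column); terms occurring in more than two documents are dropped from the vocabulary:

from pyspark.ml.feature import CountVectorizer

cv = CountVectorizer(inputCol="raw", outputCol="vectors", minDF=1.0, maxDF=2.0)
cv_model = cv.fit(docs_df)
print(cv_model.vocabulary)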
@inherit_doc
class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
@@ -741,9 +819,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures,
>>> df = spark.createDataFrame(data, cols)
>>> hasher = FeatureHasher(inputCols=cols, outputCol="features")
>>> hasher.transform(df).head().features
- SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0})
+ SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasher.setCategoricalCols(["real"]).transform(df).head().features
- SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0})
+ SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0})
>>> hasherPath = temp_path + "/hasher"
>>> hasher.save(hasherPath)
>>> loadedHasher = FeatureHasher.load(hasherPath)
@@ -2264,9 +2342,38 @@ def mean(self):
return self._call_java("mean")
+class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol):
+ """
+ Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`.
+ """
+
+ stringOrderType = Param(Params._dummy(), "stringOrderType",
+ "How to order labels of string column. The first label after " +
+ "ordering is assigned an index of 0. Supported options: " +
+ "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
+ typeConverter=TypeConverters.toString)
+
+ handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
+ "or NULL values) in features and label column of string type. " +
+ "Options are 'skip' (filter out rows with invalid data), " +
+ "error (throw an error), or 'keep' (put invalid data " +
+ "in a special additional bucket, at index numLabels).",
+ typeConverter=TypeConverters.toString)
+
+ def __init__(self, *args):
+ super(_StringIndexerParams, self).__init__(*args)
+ self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc")
+
+ @since("2.3.0")
+ def getStringOrderType(self):
+ """
+ Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
+ """
+ return self.getOrDefault(self.stringOrderType)
+
+
@inherit_doc
-class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
- JavaMLWritable):
+class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLWritable):
"""
A label indexer that maps a string column of labels to an ML column of label indices.
If the input column is numeric, we cast it to string and index the string values.
@@ -2310,23 +2417,16 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
... key=lambda x: x[0])
[(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
+ >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"],
+ ... inputCol="label", outputCol="indexed", handleInvalid="error")
+ >>> result = fromlabelsModel.transform(stringIndDf)
+ >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]),
+ ... key=lambda x: x[0])
+ [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
.. versionadded:: 1.4.0
"""
- stringOrderType = Param(Params._dummy(), "stringOrderType",
- "How to order labels of string column. The first label after " +
- "ordering is assigned an index of 0. Supported options: " +
- "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
- typeConverter=TypeConverters.toString)
-
- handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
- "or NULL values) in features and label column of string type. " +
- "Options are 'skip' (filter out rows with invalid data), " +
- "error (throw an error), or 'keep' (put invalid data " +
- "in a special additional bucket, at index numLabels).",
- typeConverter=TypeConverters.toString)
-
@keyword_only
def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
stringOrderType="frequencyDesc"):
@@ -2336,7 +2436,6 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
"""
super(StringIndexer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
- self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc")
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -2362,21 +2461,33 @@ def setStringOrderType(self, value):
"""
return self._set(stringOrderType=value)
- @since("2.3.0")
- def getStringOrderType(self):
- """
- Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
- """
- return self.getOrDefault(self.stringOrderType)
-
-class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
+class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable):
"""
Model fitted by :py:class:`StringIndexer`.
.. versionadded:: 1.4.0
"""
+ @classmethod
+ @since("2.4.0")
+ def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None):
+ """
+ Construct the model directly from an array of label strings,
+ requires an active SparkContext.
+ """
+ sc = SparkContext._active_spark_context
+ java_class = sc._gateway.jvm.java.lang.String
+ jlabels = StringIndexerModel._new_java_array(labels, java_class)
+ model = StringIndexerModel._create_from_java_class(
+ "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
+ model.setInputCol(inputCol)
+ if outputCol is not None:
+ model.setOutputCol(outputCol)
+ if handleInvalid is not None:
+ model.setHandleInvalid(handleInvalid)
+ return model
+
@property
@since("1.5.0")
def labels(self):
@@ -2385,6 +2496,13 @@ def labels(self):
"""
return self._call_java("labels")
+ @since("2.4.0")
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
@inherit_doc
class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
@@ -2464,25 +2582,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
typeConverter=TypeConverters.toListString)
caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
"comparison over the stop words", typeConverter=TypeConverters.toBoolean)
+ locale = Param(Params._dummy(), "locale", "locale of the input. ignored when case sensitive " +
+ "is true", typeConverter=TypeConverters.toString)
@keyword_only
- def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
+ def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
+ locale=None):
"""
- __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
+ __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
+ locale=None)
"""
super(StopWordsRemover, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
self.uid)
self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
- caseSensitive=False)
+ caseSensitive=False, locale=self._java_obj.getLocale())
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.6.0")
- def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
+ def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False,
+ locale=None):
"""
- setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
+ setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \
+ locale=None)
Sets params for this StopWordRemover.
"""
kwargs = self._input_kwargs
@@ -2516,6 +2640,20 @@ def getCaseSensitive(self):
"""
return self.getOrDefault(self.caseSensitive)
+ @since("2.4.0")
+ def setLocale(self, value):
+ """
+ Sets the value of :py:attr:`locale`.
+ """
+ return self._set(locale=value)
+
+ @since("2.4.0")
+ def getLocale(self):
+ """
+ Gets the value of :py:attr:`locale`.
+ """
+ return self.getOrDefault(self.locale)
+
@staticmethod
@since("2.0.0")
def loadDefaultStopWords(language):
@@ -2583,7 +2721,8 @@ def setParams(self, inputCol=None, outputCol=None):
@inherit_doc
-class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable,
+ JavaMLWritable):
"""
A feature transformer that merges multiple columns into a vector column.
@@ -2601,25 +2740,56 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
>>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
>>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs
True
+ >>> dfWithNullsAndNaNs = spark.createDataFrame(
+ ... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"])
+ >>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features",
+ ... handleInvalid="keep")
+ >>> vecAssembler2.transform(dfWithNullsAndNaNs).show()
+ +---+---+----+-------------+
+ | a| b| c| features|
+ +---+---+----+-------------+
+ |1.0|2.0|null|[1.0,2.0,NaN]|
+ |3.0|NaN| 4.0|[3.0,NaN,4.0]|
+ |5.0|6.0| 7.0|[5.0,6.0,7.0]|
+ +---+---+----+-------------+
+ ...
+ >>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show()
+ +---+---+---+-------------+
+ | a| b| c| features|
+ +---+---+---+-------------+
+ |5.0|6.0|7.0|[5.0,6.0,7.0]|
+ +---+---+---+-------------+
+ ...
.. versionadded:: 1.4.0
"""
+ handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " +
+ "and NaN values). Options are 'skip' (filter out rows with invalid " +
+ "data), 'error' (throw an error), or 'keep' (return relevant number " +
+ "of NaN in the output). Column lengths are taken from the size of ML " +
+ "Attribute Group, which can be set using `VectorSizeHint` in a " +
+ "pipeline before `VectorAssembler`. Column lengths can also be " +
+ "inferred from first rows of the data since it is safe to do so but " +
+ "only in case of 'error' or 'skip').",
+ typeConverter=TypeConverters.toString)
+
@keyword_only
- def __init__(self, inputCols=None, outputCol=None):
+ def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"):
"""
- __init__(self, inputCols=None, outputCol=None)
+ __init__(self, inputCols=None, outputCol=None, handleInvalid="error")
"""
super(VectorAssembler, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
+ self._setDefault(handleInvalid="error")
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.4.0")
- def setParams(self, inputCols=None, outputCol=None):
+ def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"):
"""
- setParams(self, inputCols=None, outputCol=None)
+ setParams(self, inputCols=None, outputCol=None, handleInvalid="error")
Sets params for this VectorAssembler.
"""
kwargs = self._input_kwargs
@@ -3717,4 +3887,4 @@ def setSize(self, value):
except OSError:
pass
if failure_count:
- exit(-1)
+ sys.exit(-1)
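Editor's note (not part of the patch): a minimal, hypothetical sketch of the feature.py additions above — StringIndexerModel.from_labels, the StopWordsRemover locale param, and VectorAssembler's handleInvalid — assuming an active SparkSession bound to `spark`, as in the doctests.

    from pyspark.ml.feature import StopWordsRemover, StringIndexerModel, VectorAssembler
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed entry point

    # Build a StringIndexerModel directly from labels; unseen values go to an extra bucket.
    df = spark.createDataFrame([(0, "a"), (1, "c"), (2, "b")], ["id", "label"])
    sim = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label",
                                         outputCol="indexed", handleInvalid="keep")
    sim.transform(df).show()

    # StopWordsRemover with an explicit locale for case-insensitive matching.
    remover = StopWordsRemover(inputCol="words", outputCol="filtered",
                               stopWords=["the", "a"], locale="en_US")
    remover.transform(spark.createDataFrame([(["the", "quick", "fox"],)], ["words"])).show()

    # VectorAssembler keeping rows that contain nulls or NaNs instead of failing.
    va = VectorAssembler(inputCols=["a", "b"], outputCol="features", handleInvalid="keep")
    va.transform(spark.createDataFrame([(1.0, None), (2.0, 3.0)], ["a", "b"])).show()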
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index b8dafd49d354d..fd19fd96c4df6 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -16,8 +16,9 @@
#
from pyspark import keyword_only, since
+from pyspark.sql import DataFrame
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm
from pyspark.ml.param.shared import *
__all__ = ["FPGrowth", "FPGrowthModel"]
@@ -243,3 +244,104 @@ def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items",
def _create_model(self, java_model):
return FPGrowthModel(java_model)
+
+
+class PrefixSpan(JavaParams):
+ """
+ .. note:: Experimental
+
+ A parallel PrefixSpan algorithm to mine frequent sequential patterns.
+ The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
+ Efficiently by Prefix-Projected Pattern Growth
+ (see here).
+ This class is not yet an Estimator/Transformer; use the
+ :py:func:`findFrequentSequentialPatterns` method to run the PrefixSpan algorithm.
+
+ @see Sequential Pattern Mining
+ (Wikipedia)
+ .. versionadded:: 2.4.0
+
+ """
+
+ minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " +
+ "sequential pattern. Sequential pattern that appears more than " +
+ "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.",
+ typeConverter=TypeConverters.toFloat)
+
+ maxPatternLength = Param(Params._dummy(), "maxPatternLength",
+ "The maximal length of the sequential pattern. Must be > 0.",
+ typeConverter=TypeConverters.toInt)
+
+ maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize",
+ "The maximum number of items (including delimiters used in the " +
+ "internal storage format) allowed in a projected database before " +
+ "local processing. If a projected database exceeds this size, " +
+ "another iteration of distributed prefix growth is run. " +
+ "Must be > 0.",
+ typeConverter=TypeConverters.toInt)
+
+ sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " +
+ "dataset, rows with nulls in this column are ignored.",
+ typeConverter=TypeConverters.toString)
+
+ @keyword_only
+ def __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
+ sequenceCol="sequence"):
+ """
+ __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
+ sequenceCol="sequence")
+ """
+ super(PrefixSpan, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid)
+ self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
+ sequenceCol="sequence")
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("2.4.0")
+ def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000,
+ sequenceCol="sequence"):
+ """
+ setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
+ sequenceCol="sequence")
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+ @since("2.4.0")
+ def findFrequentSequentialPatterns(self, dataset):
+ """
+ .. note:: Experimental
+ Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
+
+ :param dataset: A DataFrame containing a sequence column, which is of
+ `ArrayType(ArrayType(T))` type, where T is the item type for the input dataset.
+ :return: A `DataFrame` that contains columns of sequence and corresponding frequency.
+ Its schema will be:
+ - `sequence: ArrayType(ArrayType(T))` (T is the item type)
+ - `freq: Long`
+
+ >>> from pyspark.ml.fpm import PrefixSpan
+ >>> from pyspark.sql import Row
+ >>> df = sc.parallelize([Row(sequence=[[1, 2], [3]]),
+ ... Row(sequence=[[1], [3, 2], [1, 2]]),
+ ... Row(sequence=[[1, 2], [5]]),
+ ... Row(sequence=[[6]])]).toDF()
+ >>> prefixSpan = PrefixSpan(minSupport=0.5, maxPatternLength=5)
+ >>> prefixSpan.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False)
+ +----------+----+
+ |sequence |freq|
+ +----------+----+
+ |[[1]] |3 |
+ |[[1], [3]]|2 |
+ |[[1, 2]] |3 |
+ |[[2]] |3 |
+ |[[3]] |2 |
+ +----------+----+
+
+ .. versionadded:: 2.4.0
+ """
+ self._transfer_params_to_java()
+ jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
+ return DataFrame(jdf, dataset.sql_ctx)
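Editor's note (not part of the patch): a small usage sketch for the PrefixSpan class added above, assuming an active SparkSession bound to `spark`; it mirrors the doctest but builds the input with createDataFrame.

    from pyspark.ml.fpm import PrefixSpan
    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed entry point

    # Each row holds one sequence: an array of itemsets (array of arrays).
    df = spark.createDataFrame([Row(sequence=[[1, 2], [3]]),
                                Row(sequence=[[1], [3, 2], [1, 2]]),
                                Row(sequence=[[1, 2], [5]]),
                                Row(sequence=[[6]])])
    ps = PrefixSpan(minSupport=0.5, maxPatternLength=5, sequenceCol="sequence")
    # Not an Estimator/Transformer yet: call the method directly to get a DataFrame
    # with `sequence` and `freq` columns.
    ps.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False)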
diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py
index 45c936645f2a8..5f0c57ee3cc67 100644
--- a/python/pyspark/ml/image.py
+++ b/python/pyspark/ml/image.py
@@ -24,11 +24,15 @@
:members:
"""
+import sys
+
import numpy as np
from pyspark import SparkContext
from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string
from pyspark.sql import DataFrame, SparkSession
+__all__ = ["ImageSchema"]
+
class _ImageSchema(object):
"""
@@ -251,7 +255,7 @@ def _test():
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py
index ad1b487676fa7..6a611a2b5b59d 100644
--- a/python/pyspark/ml/linalg/__init__.py
+++ b/python/pyspark/ml/linalg/__init__.py
@@ -1158,7 +1158,7 @@ def _test():
import doctest
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
_test()
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index db951d81de1e7..6e9e0a34cdfde 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -157,6 +157,11 @@ def get$Name(self):
"TypeConverters.toInt"),
("parallelism", "the number of threads to use when running parallel algorithms (>= 1).",
"1", "TypeConverters.toInt"),
+ ("collectSubModels", "Param for whether to collect a list of sub-models trained during " +
+ "tuning. If set to false, then only the single best sub-model will be available after " +
+ "fitting. If set to true, then all sub-models will be available. Warning: For large " +
+ "models, collecting all sub-models can cause OOMs on the Spark driver.",
+ "False", "TypeConverters.toBoolean"),
("loss", "the loss function to be optimized.", None, "TypeConverters.toString")]
code = []
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 474c38764e5a1..08408ee8fbfcc 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -655,6 +655,30 @@ def getParallelism(self):
return self.getOrDefault(self.parallelism)
+class HasCollectSubModels(Params):
+ """
+ Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.
+ """
+
+ collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean)
+
+ def __init__(self):
+ super(HasCollectSubModels, self).__init__()
+ self._setDefault(collectSubModels=False)
+
+ def setCollectSubModels(self, value):
+ """
+ Sets the value of :py:attr:`collectSubModels`.
+ """
+ return self._set(collectSubModels=value)
+
+ def getCollectSubModels(self):
+ """
+ Gets the value of collectSubModels or its default value.
+ """
+ return self.getOrDefault(self.collectSubModels)
+
+
class HasLoss(Params):
"""
Mixin for param loss: the loss function to be optimized.
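Editor's note (not part of the patch): the HasCollectSubModels mixin above is wired into CrossValidator and TrainValidationSplit in the tuning.py changes later in this diff. A hedged sketch of how a caller might use it, assuming an active SparkSession bound to `spark`:

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed entry point

    train = spark.createDataFrame([(Vectors.dense([0.0]), 0.0),
                                   (Vectors.dense([1.0]), 1.0)] * 10,
                                  ["features", "label"])
    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=BinaryClassificationEvaluator(),
                        numFolds=3, collectSubModels=True)
    cvModel = cv.fit(train)
    # subModels is a list of numFolds lists, one fitted model per param map.
    # Collecting them can exhaust driver memory for large models.
    print(len(cvModel.subModels), len(cvModel.subModels[0]))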
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index e8bcbe4cd34cb..a8eae9bd268d3 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -15,6 +15,8 @@
# limitations under the License.
#
+import sys
+
from pyspark import since, keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
@@ -480,4 +482,4 @@ def recommendForItemSubset(self, dataset, numUsers):
except OSError:
pass
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index f0812bd1d4a39..dba0e57b01a0b 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import sys
import warnings
from pyspark import since, keyword_only
@@ -335,10 +336,10 @@ def rootMeanSquaredError(self):
@since("2.0.0")
def r2(self):
"""
- Returns R^2^, the coefficient of determination.
+ Returns R^2, the coefficient of determination.
.. seealso:: `Wikipedia coefficient of determination \
- `
+ `_
.. note:: This ignores instance weights (setting all to 1.0) from
`LinearRegression.weightCol`. This will change in later Spark
@@ -346,6 +347,20 @@ def r2(self):
"""
return self._call_java("r2")
+ @property
+ @since("2.4.0")
+ def r2adj(self):
+ """
+ Returns Adjusted R^2, the adjusted coefficient of determination.
+
+ .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 \
+ `_
+
+ .. note:: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark versions.
+ """
+ return self._call_java("r2adj")
+
@property
@since("2.0.0")
def residuals(self):
@@ -587,6 +602,14 @@ class TreeEnsembleParams(DecisionTreeParams):
"used for learning each decision tree, in range (0, 1].",
typeConverter=TypeConverters.toFloat)
+ supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]
+
+ featureSubsetStrategy = \
+ Param(Params._dummy(), "featureSubsetStrategy",
+ "The number of features to consider for splits at each tree node. Supported " +
+ "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].",
+ typeConverter=TypeConverters.toString)
+
def __init__(self):
super(TreeEnsembleParams, self).__init__()
@@ -604,6 +627,22 @@ def getSubsamplingRate(self):
"""
return self.getOrDefault(self.subsamplingRate)
+ @since("1.4.0")
+ def setFeatureSubsetStrategy(self, value):
+ """
+ Sets the value of :py:attr:`featureSubsetStrategy`.
+
+ .. note:: Deprecated in 2.4.0 and will be removed in 3.0.0.
+ """
+ return self._set(featureSubsetStrategy=value)
+
+ @since("1.4.0")
+ def getFeatureSubsetStrategy(self):
+ """
+ Gets the value of featureSubsetStrategy or its default value.
+ """
+ return self.getOrDefault(self.featureSubsetStrategy)
+
class TreeRegressorParams(Params):
"""
@@ -639,14 +678,8 @@ class RandomForestParams(TreeEnsembleParams):
Private class to track supported random forest parameters.
"""
- supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]
numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).",
typeConverter=TypeConverters.toInt)
- featureSubsetStrategy = \
- Param(Params._dummy(), "featureSubsetStrategy",
- "The number of features to consider for splits at each tree node. Supported " +
- "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].",
- typeConverter=TypeConverters.toString)
def __init__(self):
super(RandomForestParams, self).__init__()
@@ -665,20 +698,6 @@ def getNumTrees(self):
"""
return self.getOrDefault(self.numTrees)
- @since("1.4.0")
- def setFeatureSubsetStrategy(self, value):
- """
- Sets the value of :py:attr:`featureSubsetStrategy`.
- """
- return self._set(featureSubsetStrategy=value)
-
- @since("1.4.0")
- def getFeatureSubsetStrategy(self):
- """
- Gets the value of featureSubsetStrategy or its default value.
- """
- return self.getOrDefault(self.featureSubsetStrategy)
-
class GBTParams(TreeEnsembleParams):
"""
@@ -966,6 +985,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return RandomForestRegressionModel(java_model)
+ @since("2.4.0")
+ def setFeatureSubsetStrategy(self, value):
+ """
+ Sets the value of :py:attr:`featureSubsetStrategy`.
+ """
+ return self._set(featureSubsetStrategy=value)
+
class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
@@ -1014,6 +1040,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
>>> print(gbt.getImpurity())
variance
+ >>> print(gbt.getFeatureSubsetStrategy())
+ all
>>> model = gbt.fit(df)
>>> model.featureImportances
SparseVector(1, {0: 1.0})
@@ -1041,6 +1069,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
True
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+ >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
+ ... ["label", "features"])
+ >>> model.evaluateEachIteration(validation, "squared")
+ [0.0, 0.0, 0.0, 0.0, 0.0]
.. versionadded:: 1.4.0
"""
@@ -1060,20 +1092,20 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
- impurity="variance"):
+ impurity="variance", featureSubsetStrategy="all"):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
- impurity="variance")
+ impurity="variance", featureSubsetStrategy="all")
"""
super(GBTRegressor, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
- impurity="variance")
+ impurity="variance", featureSubsetStrategy="all")
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1083,13 +1115,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
- impuriy="variance"):
+ impuriy="variance", featureSubsetStrategy="all"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
- impurity="variance")
+ impurity="variance", featureSubsetStrategy="all")
Sets params for Gradient Boosted Tree Regression.
"""
kwargs = self._input_kwargs
@@ -1112,6 +1144,13 @@ def getLossType(self):
"""
return self.getOrDefault(self.lossType)
+ @since("2.4.0")
+ def setFeatureSubsetStrategy(self, value):
+ """
+ Sets the value of :py:attr:`featureSubsetStrategy`.
+ """
+ return self._set(featureSubsetStrategy=value)
+
class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
@@ -1141,6 +1180,20 @@ def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+ @since("2.4.0")
+ def evaluateEachIteration(self, dataset, loss):
+ """
+ Method to compute error or loss for every iteration of gradient boosting.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ :param loss:
+ The loss function used to compute error.
+ Supported options: squared, absolute
+ """
+ return self._call_java("evaluateEachIteration", dataset, loss)
+
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
@@ -1812,4 +1865,4 @@ def __repr__(self):
except OSError:
pass
if failure_count:
- exit(-1)
+ sys.exit(-1)
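Editor's note (not part of the patch): a minimal sketch of the regression.py additions above — the adjusted R^2 on the linear regression training summary and evaluateEachIteration on a fitted GBT model — assuming an active SparkSession bound to `spark`.

    from pyspark.ml.linalg import Vectors
    from pyspark.ml.regression import GBTRegressor, LinearRegression
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed entry point

    df = spark.createDataFrame([(1.0, Vectors.dense(0.0)),
                                (3.0, Vectors.dense(2.0)),
                                (5.0, Vectors.dense(4.0))], ["label", "features"])

    # Adjusted R^2 is now exposed alongside r2 on the training summary.
    summary = LinearRegression(maxIter=5).fit(df).summary
    print(summary.r2, summary.r2adj)

    # Per-iteration squared-error loss of a gradient-boosted ensemble on a validation set.
    gbtModel = GBTRegressor(maxIter=3, maxDepth=2, seed=42).fit(df)
    print(gbtModel.evaluateEachIteration(df, "squared"))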
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index 079b0833e1c6d..a06ab31a7a56a 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -15,9 +15,13 @@
# limitations under the License.
#
+import sys
+
from pyspark import since, SparkContext
from pyspark.ml.common import _java2py, _py2java
-from pyspark.ml.wrapper import _jvm
+from pyspark.ml.wrapper import JavaWrapper, _jvm
+from pyspark.sql.column import Column, _to_seq
+from pyspark.sql.functions import lit
class ChiSquareTest(object):
@@ -30,32 +34,6 @@ class ChiSquareTest(object):
The null hypothesis is that the occurrence of the outcomes is statistically independent.
- :param dataset:
- DataFrame of categorical labels and categorical features.
- Real-valued features will be treated as categorical for each distinct value.
- :param featuresCol:
- Name of features column in dataset, of type `Vector` (`VectorUDT`).
- :param labelCol:
- Name of label column in dataset, of any numerical type.
- :return:
- DataFrame containing the test result for every feature against the label.
- This DataFrame will contain a single Row with the following fields:
- - `pValues: Vector`
- - `degreesOfFreedom: Array[Int]`
- - `statistics: Vector`
- Each of these fields has one value per feature.
-
- >>> from pyspark.ml.linalg import Vectors
- >>> from pyspark.ml.stat import ChiSquareTest
- >>> dataset = [[0, Vectors.dense([0, 0, 1])],
- ... [0, Vectors.dense([1, 0, 1])],
- ... [1, Vectors.dense([2, 1, 1])],
- ... [1, Vectors.dense([3, 1, 1])]]
- >>> dataset = spark.createDataFrame(dataset, ["label", "features"])
- >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label')
- >>> chiSqResult.select("degreesOfFreedom").collect()[0]
- Row(degreesOfFreedom=[3, 1, 0])
-
.. versionadded:: 2.2.0
"""
@@ -64,6 +42,32 @@ class ChiSquareTest(object):
def test(dataset, featuresCol, labelCol):
"""
Perform a Pearson's independence test using dataset.
+
+ :param dataset:
+ DataFrame of categorical labels and categorical features.
+ Real-valued features will be treated as categorical for each distinct value.
+ :param featuresCol:
+ Name of features column in dataset, of type `Vector` (`VectorUDT`).
+ :param labelCol:
+ Name of label column in dataset, of any numerical type.
+ :return:
+ DataFrame containing the test result for every feature against the label.
+ This DataFrame will contain a single Row with the following fields:
+ - `pValues: Vector`
+ - `degreesOfFreedom: Array[Int]`
+ - `statistics: Vector`
+ Each of these fields has one value per feature.
+
+ >>> from pyspark.ml.linalg import Vectors
+ >>> from pyspark.ml.stat import ChiSquareTest
+ >>> dataset = [[0, Vectors.dense([0, 0, 1])],
+ ... [0, Vectors.dense([1, 0, 1])],
+ ... [1, Vectors.dense([2, 1, 1])],
+ ... [1, Vectors.dense([3, 1, 1])]]
+ >>> dataset = spark.createDataFrame(dataset, ["label", "features"])
+ >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label')
+ >>> chiSqResult.select("degreesOfFreedom").collect()[0]
+ Row(degreesOfFreedom=[3, 1, 0])
"""
sc = SparkContext._active_spark_context
javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest
@@ -83,40 +87,6 @@ class Correlation(object):
which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'`
to avoid recomputing the common lineage.
- :param dataset:
- A dataset or a dataframe.
- :param column:
- The name of the column of vectors for which the correlation coefficient needs
- to be computed. This must be a column of the dataset, and it must contain
- Vector objects.
- :param method:
- String specifying the method to use for computing correlation.
- Supported: `pearson` (default), `spearman`.
- :return:
- A dataframe that contains the correlation matrix of the column of vectors. This
- dataframe contains a single row and a single column of name
- '$METHODNAME($COLUMN)'.
-
- >>> from pyspark.ml.linalg import Vectors
- >>> from pyspark.ml.stat import Correlation
- >>> dataset = [[Vectors.dense([1, 0, 0, -2])],
- ... [Vectors.dense([4, 5, 0, 3])],
- ... [Vectors.dense([6, 7, 0, 8])],
- ... [Vectors.dense([9, 0, 0, 1])]]
- >>> dataset = spark.createDataFrame(dataset, ['features'])
- >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0]
- >>> print(str(pearsonCorr).replace('nan', 'NaN'))
- DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...],
- [ 0.0556..., 1. , NaN, 0.9135...],
- [ NaN, NaN, 1. , NaN],
- [ 0.4004..., 0.9135..., NaN, 1. ]])
- >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0]
- >>> print(str(spearmanCorr).replace('nan', 'NaN'))
- DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ],
- [ 0.1054..., 1. , NaN, 0.9486... ],
- [ NaN, NaN, 1. , NaN],
- [ 0.4 , 0.9486... , NaN, 1. ]])
-
.. versionadded:: 2.2.0
"""
@@ -125,6 +95,40 @@ class Correlation(object):
def corr(dataset, column, method="pearson"):
"""
Compute the correlation matrix with specified method using dataset.
+
+ :param dataset:
+ A Dataset or a DataFrame.
+ :param column:
+ The name of the column of vectors for which the correlation coefficient needs
+ to be computed. This must be a column of the dataset, and it must contain
+ Vector objects.
+ :param method:
+ String specifying the method to use for computing correlation.
+ Supported: `pearson` (default), `spearman`.
+ :return:
+ A DataFrame that contains the correlation matrix of the column of vectors. This
+ DataFrame contains a single row and a single column of name
+ '$METHODNAME($COLUMN)'.
+
+ >>> from pyspark.ml.linalg import Vectors
+ >>> from pyspark.ml.stat import Correlation
+ >>> dataset = [[Vectors.dense([1, 0, 0, -2])],
+ ... [Vectors.dense([4, 5, 0, 3])],
+ ... [Vectors.dense([6, 7, 0, 8])],
+ ... [Vectors.dense([9, 0, 0, 1])]]
+ >>> dataset = spark.createDataFrame(dataset, ['features'])
+ >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0]
+ >>> print(str(pearsonCorr).replace('nan', 'NaN'))
+ DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...],
+ [ 0.0556..., 1. , NaN, 0.9135...],
+ [ NaN, NaN, 1. , NaN],
+ [ 0.4004..., 0.9135..., NaN, 1. ]])
+ >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0]
+ >>> print(str(spearmanCorr).replace('nan', 'NaN'))
+ DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ],
+ [ 0.1054..., 1. , NaN, 0.9486... ],
+ [ NaN, NaN, 1. , NaN],
+ [ 0.4 , 0.9486... , NaN, 1. ]])
"""
sc = SparkContext._active_spark_context
javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation
@@ -132,6 +136,256 @@ def corr(dataset, column, method="pearson"):
return _java2py(sc, javaCorrObj.corr(*args))
+class KolmogorovSmirnovTest(object):
+ """
+ .. note:: Experimental
+
+ Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a continuous
+ distribution.
+
+ By comparing the largest difference between the empirical cumulative
+ distribution of the sample data and the theoretical distribution, we can provide a test for
+ the null hypothesis that the sample data comes from that theoretical distribution.
+
+ .. versionadded:: 2.4.0
+
+ """
+ @staticmethod
+ @since("2.4.0")
+ def test(dataset, sampleCol, distName, *params):
+ """
+ Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution
+ equality. Currently supports the normal distribution, taking as parameters the mean and
+ standard deviation.
+
+ :param dataset:
+ a Dataset or a DataFrame containing the sample of data to test.
+ :param sampleCol:
+ Name of sample column in dataset, of any numerical type.
+ :param distName:
+ a `string` name for a theoretical distribution, currently only support "norm".
+ :param params:
+ a list of `Double` values specifying the parameters to be used for the theoretical
+ distribution. For "norm" distribution, the parameters includes mean and variance.
+ :return:
+ A DataFrame that contains the Kolmogorov-Smirnov test result for the input sampled data.
+ This DataFrame will contain a single Row with the following fields:
+ - `pValue: Double`
+ - `statistic: Double`
+
+ >>> from pyspark.ml.stat import KolmogorovSmirnovTest
+ >>> dataset = [[-1.0], [0.0], [1.0]]
+ >>> dataset = spark.createDataFrame(dataset, ['sample'])
+ >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 0.0, 1.0).first()
+ >>> round(ksResult.pValue, 3)
+ 1.0
+ >>> round(ksResult.statistic, 3)
+ 0.175
+ >>> dataset = [[2.0], [3.0], [4.0]]
+ >>> dataset = spark.createDataFrame(dataset, ['sample'])
+ >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 3.0, 1.0).first()
+ >>> round(ksResult.pValue, 3)
+ 1.0
+ >>> round(ksResult.statistic, 3)
+ 0.175
+ """
+ sc = SparkContext._active_spark_context
+ javaTestObj = _jvm().org.apache.spark.ml.stat.KolmogorovSmirnovTest
+ dataset = _py2java(sc, dataset)
+ params = [float(param) for param in params]
+ return _java2py(sc, javaTestObj.test(dataset, sampleCol, distName,
+ _jvm().PythonUtils.toSeq(params)))
+
+
+class Summarizer(object):
+ """
+ .. note:: Experimental
+
+ Tools for vectorized statistics on MLlib Vectors.
+ The methods in this package provide various statistics for Vectors contained inside DataFrames.
+ This class lets users pick the statistics they would like to extract for a given column.
+
+ >>> from pyspark.ml.stat import Summarizer
+ >>> from pyspark.sql import Row
+ >>> from pyspark.ml.linalg import Vectors
+ >>> summarizer = Summarizer.metrics("mean", "count")
+ >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
+ ... Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
+ >>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
+ +-----------------------------------+
+ |aggregate_metrics(features, weight)|
+ +-----------------------------------+
+ |[[1.0,1.0,1.0], 1] |
+ +-----------------------------------+
+
+ >>> df.select(summarizer.summary(df.features)).show(truncate=False)
+ +--------------------------------+
+ |aggregate_metrics(features, 1.0)|
+ +--------------------------------+
+ |[[1.0,1.5,2.0], 2] |
+ +--------------------------------+
+
+ >>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False)
+ +--------------+
+ |mean(features)|
+ +--------------+
+ |[1.0,1.0,1.0] |
+ +--------------+
+
+ >>> df.select(Summarizer.mean(df.features)).show(truncate=False)
+ +--------------+
+ |mean(features)|
+ +--------------+
+ |[1.0,1.5,2.0] |
+ +--------------+
+
+
+ .. versionadded:: 2.4.0
+
+ """
+ @staticmethod
+ @since("2.4.0")
+ def mean(col, weightCol=None):
+ """
+ return a column of mean summary
+ """
+ return Summarizer._get_single_metric(col, weightCol, "mean")
+
+ @staticmethod
+ @since("2.4.0")
+ def variance(col, weightCol=None):
+ """
+ return a column of variance summary
+ """
+ return Summarizer._get_single_metric(col, weightCol, "variance")
+
+ @staticmethod
+ @since("2.4.0")
+ def count(col, weightCol=None):
+ """
+ return a column of count summary
+ """
+ return Summarizer._get_single_metric(col, weightCol, "count")
+
+ @staticmethod
+ @since("2.4.0")
+ def numNonZeros(col, weightCol=None):
+ """
+ return a column of numNonZero summary
+ """
+ return Summarizer._get_single_metric(col, weightCol, "numNonZeros")
+
+ @staticmethod
+ @since("2.4.0")
+ def max(col, weightCol=None):
+ """
+ return a column of max summary
+ """
+ return Summarizer._get_single_metric(col, weightCol, "max")
+
+ @staticmethod
+ @since("2.4.0")
+ def min(col, weightCol=None):
+ """
+ return a column of min summary
+ """
+ return Summarizer._get_single_metric(col, weightCol, "min")
+
+ @staticmethod
+ @since("2.4.0")
+ def normL1(col, weightCol=None):
+ """
+ return a column of normL1 summary
+ """
+ return Summarizer._get_single_metric(col, weightCol, "normL1")
+
+ @staticmethod
+ @since("2.4.0")
+ def normL2(col, weightCol=None):
+ """
+ return a column of normL2 summary
+ """
+ return Summarizer._get_single_metric(col, weightCol, "normL2")
+
+ @staticmethod
+ def _check_param(featuresCol, weightCol):
+ if weightCol is None:
+ weightCol = lit(1.0)
+ if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column):
+ raise TypeError("featureCol and weightCol should be a Column")
+ return featuresCol, weightCol
+
+ @staticmethod
+ def _get_single_metric(col, weightCol, metric):
+ col, weightCol = Summarizer._check_param(col, weightCol)
+ return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric,
+ col._jc, weightCol._jc))
+
+ @staticmethod
+ @since("2.4.0")
+ def metrics(*metrics):
+ """
+ Given a list of metrics, provides a builder that in turn computes the metrics from a column.
+
+ See the documentation of :py:class:`Summarizer` for an example.
+
+ The following metrics are accepted (case sensitive):
+ - mean: a vector that contains the coefficient-wise mean.
+ - variance: a vector that contains the coefficient-wise variance.
+ - count: the count of all vectors seen.
+ - numNonzeros: a vector with the number of non-zeros for each coefficient.
+ - max: the maximum for each coefficient.
+ - min: the minimum for each coefficient.
+ - normL2: the Euclidean norm for each coefficient.
+ - normL1: the L1 norm of each coefficient (sum of the absolute values).
+
+ :param metrics:
+ the metrics to be computed.
+ :return:
+ an object of :py:class:`pyspark.ml.stat.SummaryBuilder`
+
+ Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD
+ interface.
+ """
+ sc = SparkContext._active_spark_context
+ js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics",
+ _to_seq(sc, metrics))
+ return SummaryBuilder(js)
+
+
+class SummaryBuilder(JavaWrapper):
+ """
+ .. note:: Experimental
+
+ A builder object that provides summary statistics about a given column.
+
+ Users should not directly create such builders, but instead use one of the methods in
+ :py:class:`pyspark.ml.stat.Summarizer`
+
+ .. versionadded:: 2.4.0
+
+ """
+ def __init__(self, jSummaryBuilder):
+ super(SummaryBuilder, self).__init__(jSummaryBuilder)
+
+ @since("2.4.0")
+ def summary(self, featuresCol, weightCol=None):
+ """
+ Returns an aggregate object that contains the summary of the column with the requested
+ metrics.
+
+ :param featuresCol:
+ a column of Vector objects containing the features, one per row.
+ :param weightCol:
+ a column that contains the weight value. The default weight is 1.0.
+ :return:
+ an aggregate column that contains the statistics. The exact content of this
+ structure is determined during the creation of the builder.
+ """
+ featuresCol, weightCol = Summarizer._check_param(featuresCol, weightCol)
+ return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc))
+
+
if __name__ == "__main__":
import doctest
import pyspark.ml.stat
@@ -151,4 +405,4 @@ def corr(dataset, column, method="pearson"):
failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
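Editor's note (not part of the patch): a short sketch combining the Summarizer/SummaryBuilder and KolmogorovSmirnovTest helpers added above, assuming an active SparkSession bound to `spark`.

    from pyspark.ml.linalg import Vectors
    from pyspark.ml.stat import KolmogorovSmirnovTest, Summarizer
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()  # assumed entry point

    df = spark.createDataFrame([(Vectors.dense(1.0, 1.0), 1.0),
                                (Vectors.dense(1.0, 2.0), 1.0)], ["features", "weight"])
    # Build one aggregate column that carries several metrics at once.
    stats = Summarizer.metrics("mean", "variance", "count").summary(df.features, df.weight)
    df.select(stats.alias("stats")).show(truncate=False)

    # One-sample, two-sided KS test against a standard normal distribution.
    samples = spark.createDataFrame([(-1.0,), (0.0,), (1.0,)], ["sample"])
    ks = KolmogorovSmirnovTest.test(samples, "sample", "norm", 0.0, 1.0).first()
    print(ks.pValue, ks.statistic)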
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 75d04785a0710..ebd36cbb5f7a7 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -51,7 +51,7 @@
from pyspark.ml.classification import *
from pyspark.ml.clustering import *
from pyspark.ml.common import _java2py, _py2java
-from pyspark.ml.evaluation import BinaryClassificationEvaluator, \
+from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \
MulticlassClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.fpm import FPGrowth, FPGrowthModel
@@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake):
pass
+class JavaWrapperMemoryTests(SparkSessionTestCase):
+
+ def test_java_object_gets_detached(self):
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight",
+ fitIntercept=False)
+
+ model = lr.fit(df)
+ summary = model.summary
+
+ self.assertIsInstance(model, JavaWrapper)
+ self.assertIsInstance(summary, JavaWrapper)
+ self.assertIsInstance(model, JavaParams)
+ self.assertNotIsInstance(summary, JavaParams)
+
+ error_no_object = 'Target Object ID does not exist for this gateway'
+
+ self.assertIn("LinearRegression_", model._java_obj.toString())
+ self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
+
+ model.__del__()
+
+ with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+ model._java_obj.toString()
+ self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
+
+ try:
+ summary.__del__()
+ except:
+ pass
+
+ with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+ model._java_obj.toString()
+ with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+ summary._java_obj.toString()
+
+
class ParamTypeConversionTests(PySparkTestCase):
"""
Test that param type conversion happens.
@@ -330,7 +369,7 @@ def test_property(self):
raise RuntimeError("Test property to raise error when invoked")
-class ParamTests(PySparkTestCase):
+class ParamTests(SparkSessionTestCase):
def test_copy_new_parent(self):
testParams = TestParams()
@@ -418,6 +457,9 @@ def test_kmeans_param(self):
self.assertEqual(algo.getK(), 10)
algo.setInitSteps(10)
self.assertEqual(algo.getInitSteps(), 10)
+ self.assertEqual(algo.getDistanceMeasure(), "euclidean")
+ algo.setDistanceMeasure("cosine")
+ self.assertEqual(algo.getDistanceMeasure(), "cosine")
def test_hasseed(self):
noSeedSpecd = TestParams()
@@ -472,6 +514,24 @@ def test_logistic_regression_check_thresholds(self):
LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
)
+ def test_preserve_set_state(self):
+ dataset = self.spark.createDataFrame([(0.5,)], ["data"])
+ binarizer = Binarizer(inputCol="data")
+ self.assertFalse(binarizer.isSet("threshold"))
+ binarizer.transform(dataset)
+ binarizer._transfer_params_from_java()
+ self.assertFalse(binarizer.isSet("threshold"),
+ "Params not explicitly set should remain unset after transform")
+
+ def test_default_params_transferred(self):
+ dataset = self.spark.createDataFrame([(0.5,)], ["data"])
+ binarizer = Binarizer(inputCol="data")
+ # intentionally change the pyspark default, but don't set it
+ binarizer._defaultParamMap[binarizer.outputCol] = "my_default"
+ result = binarizer.transform(dataset).select("my_default").collect()
+ self.assertFalse(binarizer.isSet(binarizer.outputCol))
+ self.assertEqual(result[0][0], 1.0)
+
@staticmethod
def check_params(test_self, py_stage, check_params_exist=True):
"""
@@ -538,6 +598,15 @@ def test_java_params(self):
self.assertEqual(evaluator._java_obj.getMetricName(), "r2")
self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae")
+ def test_clustering_evaluator_with_cosine_distance(self):
+ featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]),
+ [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0),
+ ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)])
+ dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
+ evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine")
+ self.assertEqual(evaluator.getDistanceMeasure(), "cosine")
+ self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5))
+
class FeatureTests(SparkSessionTestCase):
@@ -612,6 +681,13 @@ def test_stopwordsremover(self):
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])
+ # with locale
+ stopwords = ["BELKİ"]
+ dataset = self.spark.createDataFrame([Row(input=["belki"])])
+ stopWordRemover.setStopWords(stopwords).setLocale("tr")
+ self.assertEqual(stopWordRemover.getStopWords(), stopwords)
+ transformedDF = stopWordRemover.transform(dataset)
+ self.assertEqual(transformedDF.head().output, [])
def test_count_vectorizer_with_binary(self):
dataset = self.spark.createDataFrame([
@@ -628,6 +704,59 @@ def test_count_vectorizer_with_binary(self):
feature, expected = r
self.assertEqual(feature, expected)
+ def test_count_vectorizer_with_maxDF(self):
+ dataset = self.spark.createDataFrame([
+ (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
+ (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
+ (2, "a b".split(' '), SparseVector(3, {0: 1.0}),),
+ (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])
+ cv = CountVectorizer(inputCol="words", outputCol="features")
+ model1 = cv.setMaxDF(3).fit(dataset)
+ self.assertEqual(model1.vocabulary, ['b', 'c', 'd'])
+
+ transformedList1 = model1.transform(dataset).select("features", "expected").collect()
+
+ for r in transformedList1:
+ feature, expected = r
+ self.assertEqual(feature, expected)
+
+ model2 = cv.setMaxDF(0.75).fit(dataset)
+ self.assertEqual(model2.vocabulary, ['b', 'c', 'd'])
+
+ transformedList2 = model2.transform(dataset).select("features", "expected").collect()
+
+ for r in transformedList2:
+ feature, expected = r
+ self.assertEqual(feature, expected)
+
+ def test_count_vectorizer_from_vocab(self):
+ model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words",
+ outputCol="features", minTF=2)
+ self.assertEqual(model.vocabulary, ["a", "b", "c"])
+ self.assertEqual(model.getMinTF(), 2)
+
+ dataset = self.spark.createDataFrame([
+ (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),),
+ (1, "a a".split(' '), SparseVector(3, {0: 2.0}),),
+ (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])
+
+ transformed_list = model.transform(dataset).select("features", "expected").collect()
+
+ for r in transformed_list:
+ feature, expected = r
+ self.assertEqual(feature, expected)
+
+ # Test an empty vocabulary
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
+ CountVectorizerModel.from_vocabulary([], inputCol="words")
+
+ # Test model with default settings can transform
+ model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words")
+ transformed_list = model_default.transform(dataset)\
+ .select(model_default.getOrDefault(model_default.outputCol)).collect()
+ self.assertEqual(len(transformed_list), 3)
+
def test_rformula_force_index_label(self):
df = self.spark.createDataFrame([
(1.0, 1.0, "a"),
@@ -678,6 +807,43 @@ def test_string_indexer_handle_invalid(self):
expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
self.assertEqual(actual2, expected2)
+ def test_string_indexer_from_labels(self):
+ model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label",
+ outputCol="indexed", handleInvalid="keep")
+ self.assertEqual(model.labels, ["a", "b", "c"])
+
+ df1 = self.spark.createDataFrame([
+ (0, "a"),
+ (1, "c"),
+ (2, None),
+ (3, "b"),
+ (4, "b")], ["id", "label"])
+
+ result1 = model.transform(df1)
+ actual1 = result1.select("id", "indexed").collect()
+ expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0),
+ Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)]
+ self.assertEqual(actual1, expected1)
+
+ model_empty_labels = StringIndexerModel.from_labels(
+ [], inputCol="label", outputCol="indexed", handleInvalid="keep")
+ actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect()
+ expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0),
+ Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)]
+ self.assertEqual(actual2, expected2)
+
+ # Test model with default settings can transform
+ model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label")
+ df2 = self.spark.createDataFrame([
+ (0, "a"),
+ (1, "c"),
+ (2, "b"),
+ (3, "b"),
+ (4, "b")], ["id", "label"])
+ transformed_list = model_default.transform(df2)\
+ .select(model_default.getOrDefault(model_default.outputCol)).collect()
+ self.assertEqual(len(transformed_list), 5)
+
class HasInducedError(Params):
@@ -859,6 +1025,50 @@ def test_parallel_evaluation(self):
cvParallelModel = cv.fit(dataset)
self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics)
+ def test_expose_sub_models(self):
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+
+ numFolds = 3
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
+ numFolds=numFolds, collectSubModels=True)
+
+ def checkSubModels(subModels):
+ self.assertEqual(len(subModels), numFolds)
+ for i in range(numFolds):
+ self.assertEqual(len(subModels[i]), len(grid))
+
+ cvModel = cv.fit(dataset)
+ checkSubModels(cvModel.subModels)
+
+ # Test the default value for option "persistSubModel" to be "true"
+ testSubPath = temp_path + "/testCrossValidatorSubModels"
+ savingPathWithSubModels = testSubPath + "cvModel3"
+ cvModel.save(savingPathWithSubModels)
+ cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
+ checkSubModels(cvModel3.subModels)
+ cvModel4 = cvModel3.copy()
+ checkSubModels(cvModel4.subModels)
+
+ savingPathWithoutSubModels = testSubPath + "cvModel2"
+ cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
+ cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
+ self.assertEqual(cvModel2.subModels, None)
+
+ for i in range(numFolds):
+ for j in range(len(grid)):
+ self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid)
+
def test_save_load_nested_estimator(self):
temp_path = tempfile.mkdtemp()
dataset = self.spark.createDataFrame(
@@ -1027,6 +1237,40 @@ def test_parallel_evaluation(self):
tvsParallelModel = tvs.fit(dataset)
self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics)
+ def test_expose_sub_models(self):
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
+ collectSubModels=True)
+ tvsModel = tvs.fit(dataset)
+ self.assertEqual(len(tvsModel.subModels), len(grid))
+
+ # Test the default value for option "persistSubModel" to be "true"
+ testSubPath = temp_path + "/testTrainValidationSplitSubModels"
+ savingPathWithSubModels = testSubPath + "cvModel3"
+ tvsModel.save(savingPathWithSubModels)
+ tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels)
+ self.assertEqual(len(tvsModel3.subModels), len(grid))
+ tvsModel4 = tvsModel3.copy()
+ self.assertEqual(len(tvsModel4.subModels), len(grid))
+
+ savingPathWithoutSubModels = testSubPath + "cvModel2"
+ tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
+ tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
+ self.assertEqual(tvsModel2.subModels, None)
+
+ for i in range(len(grid)):
+ self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid)
+
def test_save_load_nested_estimator(self):
# This tests saving and loading the trained model only.
# Save/load for TrainValidationSplit will be added later: SPARK-13786
@@ -1358,6 +1602,44 @@ def test_default_read_write(self):
self.assertEqual(lr.uid, lr3.uid)
self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
+ def test_default_read_write_default_params(self):
+ lr = LogisticRegression()
+ self.assertFalse(lr.isSet(lr.getParam("threshold")))
+
+ lr.setMaxIter(50)
+ lr.setThreshold(.75)
+
+ # `threshold` is set by user, default param `predictionCol` is not set by user.
+ self.assertTrue(lr.isSet(lr.getParam("threshold")))
+ self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+ self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+ writer = DefaultParamsWriter(lr)
+ metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
+ self.assertTrue("defaultParamMap" in metadata)
+
+ reader = DefaultParamsReadable.read()
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+ loadedMetadata = reader._parseMetaData(metadataStr, )
+ reader.getAndSetParams(lr, loadedMetadata)
+
+ self.assertTrue(lr.isSet(lr.getParam("threshold")))
+ self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+ self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+ # manually create metadata without `defaultParamMap` section.
+ del metadata['defaultParamMap']
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+ loadedMetadata = reader._parseMetaData(metadataStr, )
+ with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
+ reader.getAndSetParams(lr, loadedMetadata)
+
+ # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
+ metadata['sparkVersion'] = '2.3.0'
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+ loadedMetadata = reader._parseMetaData(metadataStr, )
+ reader.getAndSetParams(lr, loadedMetadata)
+
class LDATest(SparkSessionTestCase):
@@ -1437,6 +1719,7 @@ def test_linear_regression_summary(self):
self.assertAlmostEqual(s.meanSquaredError, 0.0)
self.assertAlmostEqual(s.rootMeanSquaredError, 0.0)
self.assertAlmostEqual(s.r2, 1.0, 2)
+ self.assertAlmostEqual(s.r2adj, 1.0, 2)
self.assertTrue(isinstance(s.residuals, DataFrame))
self.assertEqual(s.numInstances, 2)
self.assertEqual(s.degreesOfFreedom, 1)
@@ -1620,6 +1903,21 @@ def test_kmeans_summary(self):
self.assertEqual(s.k, 2)
+class KMeansTests(SparkSessionTestCase):
+
+ def test_kmeans_cosine_distance(self):
+ data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),),
+ (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),),
+ (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)]
+ df = self.spark.createDataFrame(data, ["features"])
+ kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine")
+ model = kmeans.fit(df)
+ result = model.transform(df).collect()
+ self.assertTrue(result[0].prediction == result[1].prediction)
+ self.assertTrue(result[2].prediction == result[3].prediction)
+ self.assertTrue(result[4].prediction == result[5].prediction)
+
+
class OneVsRestTests(SparkSessionTestCase):
def test_copy(self):
@@ -1883,17 +2181,23 @@ class ImageReaderTest2(PySparkTestCase):
@classmethod
def setUpClass(cls):
super(ImageReaderTest2, cls).setUpClass()
+ cls.hive_available = True
# Note that here we enable Hive's support.
cls.spark = None
try:
cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
except py4j.protocol.Py4JError:
cls.tearDownClass()
- raise unittest.SkipTest("Hive is not available")
+ cls.hive_available = False
except TypeError:
cls.tearDownClass()
- raise unittest.SkipTest("Hive is not available")
- cls.spark = HiveContext._createForTesting(cls.sc)
+ cls.hive_available = False
+ if cls.hive_available:
+ cls.spark = HiveContext._createForTesting(cls.sc)
+
+ def setUp(self):
+ if not self.hive_available:
+ self.skipTest("Hive is not available.")
@classmethod
def tearDownClass(cls):
@@ -1943,18 +2247,28 @@ def test_java_params(self):
import pyspark.ml.feature
import pyspark.ml.classification
import pyspark.ml.clustering
+ import pyspark.ml.evaluation
import pyspark.ml.pipeline
import pyspark.ml.recommendation
import pyspark.ml.regression
+
modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering,
- pyspark.ml.pipeline, pyspark.ml.recommendation, pyspark.ml.regression]
+ pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation,
+ pyspark.ml.regression]
for module in modules:
for name, cls in inspect.getmembers(module, inspect.isclass):
- if not name.endswith('Model') and issubclass(cls, JavaParams)\
- and not inspect.isabstract(cls):
+ if not name.endswith('Model') and not name.endswith('Params')\
+ and issubclass(cls, JavaParams) and not inspect.isabstract(cls):
# NOTE: disable check_params_exist until there is parity with Scala API
ParamTests.check_params(self, cls(), check_params_exist=False)
+ # Additional classes that need explicit construction
+ from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel
+ ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'),
+ check_params_exist=False)
+ ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'),
+ check_params_exist=False)
+
def _squared_distance(a, b):
if isinstance(a, Vector):
@@ -2399,6 +2713,6 @@ def testDefaultFitMultiple(self):
if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
- unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
else:
- unittest.main()
+ unittest.main(verbosity=2)
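As background for the java_params change above, here is a minimal usage sketch of the two explicitly constructed models it covers; the column names and the active `spark` session are assumptions for illustration, not part of the patch:

    from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel

    # Pre-fitted models built directly from a vocabulary / label list (Spark 2.4 APIs).
    cv_model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words",
                                                    outputCol="features")
    si_model = StringIndexerModel.from_labels(["a", "b"], inputCol="label", outputCol="indexed")

    df = spark.createDataFrame([(["a", "b", "a"],)], ["words"])
    cv_model.transform(df).show()  # adds a sparse count-vector column named "features"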
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 6c0cad6cbaaa1..0c8029f293cfe 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -15,14 +15,16 @@
# limitations under the License.
#
import itertools
-import numpy as np
+import sys
from multiprocessing.pool import ThreadPool
+import numpy as np
+
from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.common import _py2java
from pyspark.ml.param import Params, Param, TypeConverters
-from pyspark.ml.param.shared import HasParallelism, HasSeed
+from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
@@ -31,7 +33,7 @@
'TrainValidationSplitModel']
-def _parallelFitTasks(est, train, eva, validation, epm):
+def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
"""
Creates a list of callables which can be called from different threads to fit and evaluate
an estimator in parallel. Each callable returns an `(index, metric)` pair.
@@ -41,14 +43,15 @@ def _parallelFitTasks(est, train, eva, validation, epm):
:param eva: Evaluator, used to compute `metric`
:param validation: DataFrame, validation data set, used for evaluation.
:param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation.
- :return: (int, float), an index into `epm` and the associated metric value.
+ :param collectSubModel: Whether to collect the fitted sub model.
+ :return: (int, float, subModel), an index into `epm`, the associated metric value, and
+ the fitted sub model (or None when sub models are not collected).
"""
modelIter = est.fitMultiple(train, epm)
def singleTask():
index, model = next(modelIter)
metric = eva.evaluate(model.transform(validation, epm[index]))
- return index, metric
+ return index, metric, model if collectSubModel else None
return [singleTask] * len(epm)
@@ -192,7 +195,8 @@ def _to_java_impl(self):
return java_estimator, java_epms, java_evaluator
-class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable):
+class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels,
+ MLReadable, MLWritable):
"""
K-fold cross validation performs model selection by splitting the dataset into a set of
@@ -231,10 +235,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW
@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
- seed=None, parallelism=1):
+ seed=None, parallelism=1, collectSubModels=False):
"""
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
- seed=None, parallelism=1)
+ seed=None, parallelism=1, collectSubModels=False)
"""
super(CrossValidator, self).__init__()
self._setDefault(numFolds=3, parallelism=1)
@@ -244,10 +248,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF
@keyword_only
@since("1.4.0")
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
- seed=None, parallelism=1):
+ seed=None, parallelism=1, collectSubModels=False):
"""
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
- seed=None, parallelism=1):
+ seed=None, parallelism=1, collectSubModels=False):
Sets params for cross validator.
"""
kwargs = self._input_kwargs
@@ -280,6 +284,10 @@ def _fit(self, dataset):
metrics = [0.0] * numModels
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
+ subModels = None
+ collectSubModelsParam = self.getCollectSubModels()
+ if collectSubModelsParam:
+ subModels = [[None for j in range(numModels)] for i in range(nFolds)]
for i in range(nFolds):
validateLB = i * h
@@ -288,9 +296,12 @@ def _fit(self, dataset):
validation = df.filter(condition).cache()
train = df.filter(~condition).cache()
- tasks = _parallelFitTasks(est, train, eva, validation, epm)
- for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+ tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
+ for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
metrics[j] += (metric / nFolds)
+ if collectSubModelsParam:
+ subModels[i][j] = subModel
+
validation.unpersist()
train.unpersist()
@@ -299,7 +310,7 @@ def _fit(self, dataset):
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return self._copyValues(CrossValidatorModel(bestModel, metrics))
+ return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels))
@since("1.4.0")
def copy(self, extra=None):
@@ -343,9 +354,11 @@ def _from_java(cls, java_stage):
numFolds = java_stage.getNumFolds()
seed = java_stage.getSeed()
parallelism = java_stage.getParallelism()
+ collectSubModels = java_stage.getCollectSubModels()
# Create a new instance of this stage.
py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
- numFolds=numFolds, seed=seed, parallelism=parallelism)
+ numFolds=numFolds, seed=seed, parallelism=parallelism,
+ collectSubModels=collectSubModels)
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -365,6 +378,7 @@ def _to_java(self):
_java_obj.setSeed(self.getSeed())
_java_obj.setNumFolds(self.getNumFolds())
_java_obj.setParallelism(self.getParallelism())
+ _java_obj.setCollectSubModels(self.getCollectSubModels())
return _java_obj
@@ -379,13 +393,15 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
.. versionadded:: 1.4.0
"""
- def __init__(self, bestModel, avgMetrics=[]):
+ def __init__(self, bestModel, avgMetrics=[], subModels=None):
super(CrossValidatorModel, self).__init__()
#: best model from cross validation
self.bestModel = bestModel
#: Average cross-validation metrics for each paramMap in
#: CrossValidator.estimatorParamMaps, in the corresponding order.
self.avgMetrics = avgMetrics
+ #: sub models from cross validation, as a nested list (one list of models per fold)
+ self.subModels = subModels
def _transform(self, dataset):
return self.bestModel.transform(dataset)
@@ -397,6 +413,7 @@ def copy(self, extra=None):
and some extra params. This copies the underlying bestModel,
creates a deep copy of the embedded paramMap, and
copies the embedded and extra parameters over.
+ It does not copy the extra Params into the subModels.
:param extra: Extra parameters to copy to the new instance
:return: Copy of this instance
@@ -405,7 +422,8 @@ def copy(self, extra=None):
extra = dict()
bestModel = self.bestModel.copy(extra)
avgMetrics = self.avgMetrics
- return CrossValidatorModel(bestModel, avgMetrics)
+ subModels = self.subModels
+ return CrossValidatorModel(bestModel, avgMetrics, subModels)
@since("2.3.0")
def write(self):
@@ -424,13 +442,17 @@ def _from_java(cls, java_stage):
Given a Java CrossValidatorModel, create and return a Python wrapper of it.
Used for ML persistence.
"""
-
bestModel = JavaParams._from_java(java_stage.bestModel())
estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
py_stage = cls(bestModel=bestModel).setEstimator(estimator)
py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
+ if java_stage.hasSubModels():
+ py_stage.subModels = [[JavaParams._from_java(sub_model)
+ for sub_model in fold_sub_models]
+ for fold_sub_models in java_stage.subModels()]
+
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -452,10 +474,16 @@ def _to_java(self):
_java_obj.set("evaluator", evaluator)
_java_obj.set("estimator", estimator)
_java_obj.set("estimatorParamMaps", epms)
+
+ if self.subModels is not None:
+ java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models]
+ for fold_sub_models in self.subModels]
+ _java_obj.setSubModels(java_sub_models)
return _java_obj
-class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable):
+class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels,
+ MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -490,10 +518,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl
@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
- parallelism=1, seed=None):
+ parallelism=1, collectSubModels=False, seed=None):
"""
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
- parallelism=1, seed=None)
+ parallelism=1, collectSubModels=False, seed=None)
"""
super(TrainValidationSplit, self).__init__()
self._setDefault(trainRatio=0.75, parallelism=1)
@@ -503,10 +531,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trai
@since("2.0.0")
@keyword_only
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
- parallelism=1, seed=None):
+ parallelism=1, collectSubModels=False, seed=None):
"""
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
- parallelism=1, seed=None):
+ parallelism=1, collectSubModels=False, seed=None):
Sets params for the train validation split.
"""
kwargs = self._input_kwargs
@@ -539,11 +567,19 @@ def _fit(self, dataset):
validation = df.filter(condition).cache()
train = df.filter(~condition).cache()
- tasks = _parallelFitTasks(est, train, eva, validation, epm)
+ subModels = None
+ collectSubModelsParam = self.getCollectSubModels()
+ if collectSubModelsParam:
+ subModels = [None for i in range(numModels)]
+
+ tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
metrics = [None] * numModels
- for j, metric in pool.imap_unordered(lambda f: f(), tasks):
+ for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
metrics[j] = metric
+ if collectSubModelsParam:
+ subModels[j] = subModel
+
train.unpersist()
validation.unpersist()
@@ -552,7 +588,7 @@ def _fit(self, dataset):
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return self._copyValues(TrainValidationSplitModel(bestModel, metrics))
+ return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels))
@since("2.0.0")
def copy(self, extra=None):
@@ -596,9 +632,11 @@ def _from_java(cls, java_stage):
trainRatio = java_stage.getTrainRatio()
seed = java_stage.getSeed()
parallelism = java_stage.getParallelism()
+ collectSubModels = java_stage.getCollectSubModels()
# Create a new instance of this stage.
py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
- trainRatio=trainRatio, seed=seed, parallelism=parallelism)
+ trainRatio=trainRatio, seed=seed, parallelism=parallelism,
+ collectSubModels=collectSubModels)
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -618,7 +656,7 @@ def _to_java(self):
_java_obj.setTrainRatio(self.getTrainRatio())
_java_obj.setSeed(self.getSeed())
_java_obj.setParallelism(self.getParallelism())
-
+ _java_obj.setCollectSubModels(self.getCollectSubModels())
return _java_obj
@@ -631,12 +669,14 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
.. versionadded:: 2.0.0
"""
- def __init__(self, bestModel, validationMetrics=[]):
+ def __init__(self, bestModel, validationMetrics=[], subModels=None):
super(TrainValidationSplitModel, self).__init__()
- #: best model from cross validation
+ #: best model from train validation split
self.bestModel = bestModel
#: evaluated validation metrics
self.validationMetrics = validationMetrics
+ #: sub models from train validation split
+ self.subModels = subModels
def _transform(self, dataset):
return self.bestModel.transform(dataset)
@@ -649,6 +689,7 @@ def copy(self, extra=None):
creates a deep copy of the embedded paramMap, and
copies the embedded and extra parameters over.
And, this creates a shallow copy of the validationMetrics.
+ It does not copy the extra Params into the subModels.
:param extra: Extra parameters to copy to the new instance
:return: Copy of this instance
@@ -657,7 +698,8 @@ def copy(self, extra=None):
extra = dict()
bestModel = self.bestModel.copy(extra)
validationMetrics = list(self.validationMetrics)
- return TrainValidationSplitModel(bestModel, validationMetrics)
+ subModels = self.subModels
+ return TrainValidationSplitModel(bestModel, validationMetrics, subModels)
@since("2.3.0")
def write(self):
@@ -685,6 +727,10 @@ def _from_java(cls, java_stage):
py_stage = cls(bestModel=bestModel).setEstimator(estimator)
py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
+ if java_stage.hasSubModels():
+ py_stage.subModels = [JavaParams._from_java(sub_model)
+ for sub_model in java_stage.subModels()]
+
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -706,6 +752,11 @@ def _to_java(self):
_java_obj.set("evaluator", evaluator)
_java_obj.set("estimator", estimator)
_java_obj.set("estimatorParamMaps", epms)
+
+ if self.subModels is not None:
+ java_sub_models = [sub_model._to_java() for sub_model in self.subModels]
+ _java_obj.setSubModels(java_sub_models)
+
return _java_obj
@@ -727,4 +778,4 @@ def _to_java(self):
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
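A minimal sketch of how the new collectSubModels flag is meant to be used end to end; the estimator, grid, evaluator and the training DataFrame `train_df` below are illustrative assumptions, not part of the patch:

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 10]).build()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=BinaryClassificationEvaluator(),
                        numFolds=3, parallelism=2, collectSubModels=True)
    cv_model = cv.fit(train_df)  # train_df: a DataFrame with "features" and "label" columns

    # subModels is a nested list: one list per fold, one fitted model per param map.
    for fold, models in enumerate(cv_model.subModels):
        print(fold, [m.uid for m in models])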
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index c3c47bd79459a..9fa85664939b8 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -30,6 +30,7 @@
from pyspark import SparkContext, since
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
+from pyspark.util import VersionUtils
def _jvm():
@@ -169,6 +170,10 @@ def overwrite(self):
self._jwrite.overwrite()
return self
+ def option(self, key, value):
+ self._jwrite.option(key, value)
+ return self
+
def context(self, sqlContext):
"""
Sets the SQL context to use for saving.
@@ -392,6 +397,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
- sparkVersion
- uid
- paramMap
+ - defaultParamMap (since 2.4.0)
- (optionally, extra metadata)
:param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
:param paramMap: If given, this is saved in the "paramMap" field.
@@ -413,15 +419,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
"""
uid = instance.uid
cls = instance.__module__ + '.' + instance.__class__.__name__
- params = instance.extractParamMap()
+
+ # User-supplied param values
+ params = instance._paramMap
jsonParams = {}
if paramMap is not None:
jsonParams = paramMap
else:
for p in params:
jsonParams[p.name] = params[p]
+
+ # Default param values
+ jsonDefaultParams = {}
+ for p in instance._defaultParamMap:
+ jsonDefaultParams[p.name] = instance._defaultParamMap[p]
+
basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
- "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams}
+ "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
+ "defaultParamMap": jsonDefaultParams}
if extraMetadata is not None:
basicMetadata.update(extraMetadata)
return json.dumps(basicMetadata, separators=[',', ':'])
@@ -519,11 +534,26 @@ def getAndSetParams(instance, metadata):
"""
Extract Params from metadata, and set them in the instance.
"""
+ # Set user-supplied param values
for paramName in metadata['paramMap']:
param = instance.getParam(paramName)
paramValue = metadata['paramMap'][paramName]
instance.set(param, paramValue)
+ # Set default param values
+ majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
+ major = majorAndMinorVersions[0]
+ minor = majorAndMinorVersions[1]
+
+ # For metadata file prior to Spark 2.4, there is no default section.
+ if major > 2 or (major == 2 and minor >= 4):
+ assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \
+ "`defaultParamMap` section, but it was not found"
+
+ for paramName in metadata['defaultParamMap']:
+ paramValue = metadata['defaultParamMap'][paramName]
+ instance._setDefault(**{paramName: paramValue})
+
@staticmethod
def loadParamsInstance(path, sc):
"""
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 0f846fbc5b5ef..d325633195ddb 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -36,6 +36,10 @@ def __init__(self, java_obj=None):
super(JavaWrapper, self).__init__()
self._java_obj = java_obj
+ def __del__(self):
+ if SparkContext._active_spark_context and self._java_obj is not None:
+ SparkContext._active_spark_context._gateway.detach(self._java_obj)
+
@classmethod
def _create_from_java_class(cls, java_class, *args):
"""
@@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params):
__metaclass__ = ABCMeta
- def __del__(self):
- if SparkContext._active_spark_context:
- SparkContext._active_spark_context._gateway.detach(self._java_obj)
-
def _make_java_param_pair(self, param, value):
"""
Makes a Java param pair.
@@ -118,11 +118,18 @@ def _transfer_params_to_java(self):
"""
Transforms the embedded params to the companion Java object.
"""
- paramMap = self.extractParamMap()
+ pair_defaults = []
for param in self.params:
- if param in paramMap:
- pair = self._make_java_param_pair(param, paramMap[param])
+ if self.isSet(param):
+ pair = self._make_java_param_pair(param, self._paramMap[param])
self._java_obj.set(pair)
+ if self.hasDefault(param):
+ pair = self._make_java_param_pair(param, self._defaultParamMap[param])
+ pair_defaults.append(pair)
+ if len(pair_defaults) > 0:
+ sc = SparkContext._active_spark_context
+ pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults)
+ self._java_obj.setDefault(pair_defaults_seq)
def _transfer_param_map_to_java(self, pyParamMap):
"""
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index cce703d432b5a..bb281981fd56b 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -16,6 +16,7 @@
#
from math import exp
+import sys
import warnings
import numpy
@@ -761,7 +762,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
_test()
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index bb687a7da6ffd..0cbabab13a896 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -1048,7 +1048,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 2cd1da3fbf9aa..36cb03369b8c0 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import sys
import warnings
from pyspark import since
@@ -542,7 +543,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index e5231dc3a27a8..40ecd2e0ff4be 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -819,7 +819,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
sys.path.pop(0)
diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
index f58ea5dfb0874..de18dad1f675d 100644
--- a/python/pyspark/mllib/fpm.py
+++ b/python/pyspark/mllib/fpm.py
@@ -15,6 +15,8 @@
# limitations under the License.
#
+import sys
+
import numpy
from numpy import array
from collections import namedtuple
@@ -197,7 +199,7 @@ def _test():
except OSError:
pass
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index 7b24b3c74a9fa..60d96d8d5ceb8 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -1370,7 +1370,7 @@ def _test():
import doctest
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
_test()
diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py
index 4cb802514be52..bba88542167ad 100644
--- a/python/pyspark/mllib/linalg/distributed.py
+++ b/python/pyspark/mllib/linalg/distributed.py
@@ -1377,7 +1377,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
_test()
diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py
index 61213ddf62e8b..a8833cb446923 100644
--- a/python/pyspark/mllib/random.py
+++ b/python/pyspark/mllib/random.py
@@ -19,6 +19,7 @@
Python package for random data generation.
"""
+import sys
from functools import wraps
from pyspark import since
@@ -421,7 +422,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 81182881352bb..3d4eae85132bb 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -16,6 +16,7 @@
#
import array
+import sys
from collections import namedtuple
from pyspark import SparkContext, since
@@ -326,7 +327,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index ea107d400621d..6be45f51862c9 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -15,9 +15,11 @@
# limitations under the License.
#
+import sys
+import warnings
+
import numpy as np
from numpy import array
-import warnings
from pyspark import RDD, since
from pyspark.streaming.dstream import DStream
@@ -837,7 +839,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
_test()
diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py
index 49b26446dbc32..3c75b132ecad2 100644
--- a/python/pyspark/mllib/stat/_statistics.py
+++ b/python/pyspark/mllib/stat/_statistics.py
@@ -313,7 +313,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 1037bab7f1088..4c2ce137e331c 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -57,6 +57,7 @@
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.linalg.distributed import RowMatrix
from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
+from pyspark.mllib.fpm import FPGrowth
from pyspark.mllib.recommendation import Rating
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
from pyspark.mllib.random import RandomRDDs
@@ -1762,14 +1763,25 @@ def test_pca(self):
self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1])
+class FPGrowthTest(MLlibTestCase):
+
+ def test_fpgrowth(self):
+ data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
+ rdd = self.sc.parallelize(data, 2)
+ model1 = FPGrowth.train(rdd, 0.6, 2)
+ # use default data partition number when numPartitions is not specified
+ model2 = FPGrowth.train(rdd, 0.6)
+ self.assertEqual(sorted(model1.freqItemsets().collect()),
+ sorted(model2.freqItemsets().collect()))
+
if __name__ == "__main__":
from pyspark.mllib.tests import *
if not _have_scipy:
print("NOTE: Skipping SciPy tests as it does not seem to be installed")
if xmlrunner:
- unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
else:
- unittest.main()
+ unittest.main(verbosity=2)
if not _have_scipy:
print("NOTE: SciPy tests were skipped as it does not seem to be installed")
sc.stop()
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 619fa16d463f5..b05734ce489d9 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -17,6 +17,7 @@
from __future__ import absolute_import
+import sys
import random
from pyspark import SparkContext, RDD, since
@@ -654,7 +655,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
_test()
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 97755807ef262..fc7809387b13a 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -521,7 +521,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
index 44d17bd629473..3c7656ab5758c 100644
--- a/python/pyspark/profiler.py
+++ b/python/pyspark/profiler.py
@@ -19,6 +19,7 @@
import pstats
import os
import atexit
+import sys
from pyspark.accumulators import AccumulatorParam
@@ -173,4 +174,4 @@ def stats(self):
import doctest
(failure_count, test_count) = doctest.testmod()
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 93b8974a7e64a..7e7e5822a6b20 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,9 +39,11 @@
else:
from itertools import imap as map, ifilter as filter
+from pyspark.java_gateway import do_server_auth
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
- PickleSerializer, pack_long, AutoBatchedSerializer
+ PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
+ UTF8Deserializer
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_full_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -51,6 +53,7 @@
from pyspark.shuffle import Aggregator, ExternalMerger, \
get_used_memory, ExternalSorter, ExternalGroupBy
from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.util import fail_on_stopiteration
__all__ = ["RDD"]
@@ -71,6 +74,7 @@ class PythonEvalType(object):
SQL_SCALAR_PANDAS_UDF = 200
SQL_GROUPED_MAP_PANDAS_UDF = 201
SQL_GROUPED_AGG_PANDAS_UDF = 202
+ SQL_WINDOW_AGG_PANDAS_UDF = 203
def portable_hash(x):
@@ -136,7 +140,8 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
-def _load_from_socket(port, serializer):
+def _load_from_socket(sock_info, serializer):
+ port, auth_secret = sock_info
sock = None
# Support for both IPv4 and IPv6.
# On most of IPv6-ready systems, IPv6 will take precedence.
@@ -156,8 +161,12 @@ def _load_from_socket(port, serializer):
# The RDD materialization time is unpredicable, if we set a timeout for socket reading
# operation, it will very possibly fail. See SPARK-18281.
sock.settimeout(None)
+
+ sockfile = sock.makefile("rwb", 65536)
+ do_server_auth(sockfile, auth_secret)
+
# The socket will be automatically closed when garbage-collected.
- return serializer.load_stream(sock.makefile("rb", 65536))
+ return serializer.load_stream(sockfile)
def ignore_unicode_prefix(f):
@@ -332,7 +341,7 @@ def map(self, f, preservesPartitioning=False):
[('a', 1), ('b', 1), ('c', 1)]
"""
def func(_, iterator):
- return map(f, iterator)
+ return map(fail_on_stopiteration(f), iterator)
return self.mapPartitionsWithIndex(func, preservesPartitioning)
def flatMap(self, f, preservesPartitioning=False):
@@ -347,7 +356,7 @@ def flatMap(self, f, preservesPartitioning=False):
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
def func(s, iterator):
- return chain.from_iterable(map(f, iterator))
+ return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)
def mapPartitions(self, f, preservesPartitioning=False):
@@ -410,7 +419,7 @@ def filter(self, f):
[2, 4]
"""
def func(iterator):
- return filter(f, iterator)
+ return filter(fail_on_stopiteration(f), iterator)
return self.mapPartitions(func, True)
def distinct(self, numPartitions=None):
@@ -791,6 +800,8 @@ def foreach(self, f):
>>> def f(x): print(x)
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
+ f = fail_on_stopiteration(f)
+
def processPartition(iterator):
for x in iterator:
f(x)
@@ -822,8 +833,8 @@ def collect(self):
to be small, as all the data is loaded into the driver's memory.
"""
with SCCallSiteSync(self.context) as css:
- port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
- return list(_load_from_socket(port, self._jrdd_deserializer))
+ sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+ return list(_load_from_socket(sock_info, self._jrdd_deserializer))
def reduce(self, f):
"""
@@ -840,6 +851,8 @@ def reduce(self, f):
...
ValueError: Can not reduce() empty RDD
"""
+ f = fail_on_stopiteration(f)
+
def func(iterator):
iterator = iter(iterator)
try:
@@ -911,6 +924,8 @@ def fold(self, zeroValue, op):
>>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
15
"""
+ op = fail_on_stopiteration(op)
+
def func(iterator):
acc = zeroValue
for obj in iterator:
@@ -943,6 +958,9 @@ def aggregate(self, zeroValue, seqOp, combOp):
>>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
(0, 0)
"""
+ seqOp = fail_on_stopiteration(seqOp)
+ combOp = fail_on_stopiteration(combOp)
+
def func(iterator):
acc = zeroValue
for obj in iterator:
@@ -1636,6 +1654,8 @@ def reduceByKeyLocally(self, func):
>>> sorted(rdd.reduceByKeyLocally(add).items())
[('a', 2), ('b', 1)]
"""
+ func = fail_on_stopiteration(func)
+
def reducePartition(iterator):
m = {}
for k, v in iterator:
@@ -2380,8 +2400,8 @@ def toLocalIterator(self):
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
with SCCallSiteSync(self.context) as css:
- port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
- return _load_from_socket(port, self._jrdd_deserializer)
+ sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+ return _load_from_socket(sock_info, self._jrdd_deserializer)
def _prepare_for_python_RDD(sc, command):
@@ -2498,7 +2518,7 @@ def _test():
globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
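A rough sketch of the fail_on_stopiteration helper wired in above; the real implementation lives in pyspark.util and its error wording may differ:

    def fail_on_stopiteration(f):
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except StopIteration as exc:
                # A StopIteration escaping user code would otherwise silently end the
                # surrounding map/filter/reduce loop and drop data; surface it instead.
                raise RuntimeError("StopIteration raised from user code; failing the task", exc)
        return wrapper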
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 91a7f093cec19..4c16b5fc26f3d 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -33,8 +33,9 @@
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.stop()
-PySpark serialize objects in batches; By default, the batch size is chosen based
-on the size of objects, also configurable by SparkContext's C{batchSize} parameter:
+PySpark serializes objects in batches; by default, the batch size is chosen based
+on the size of objects and is also configurable by SparkContext's C{batchSize}
+parameter:
>>> sc = SparkContext('local', 'test', batchSize=2)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
@@ -68,6 +69,7 @@
xrange = range
from pyspark import cloudpickle
+from pyspark.util import _exception_message
__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"]
@@ -99,7 +101,7 @@ def load_stream(self, stream):
def _load_stream_without_unbatching(self, stream):
"""
Return an iterator of deserialized batches (iterable) of objects from the input stream.
- if the serializer does not operate on batches the default implementation returns an
+ If the serializer does not operate on batches, the default implementation returns an
iterator of single element lists.
"""
return map(lambda x: [x], self.load_stream(stream))
@@ -249,6 +251,15 @@ def __init__(self, timezone):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
+ def arrow_to_pandas(self, arrow_column):
+ from pyspark.sql.types import from_arrow_type, \
+ _check_series_convert_date, _check_series_localize_timestamps
+
+ s = arrow_column.to_pandas()
+ s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
+ s = _check_series_localize_timestamps(s, self._timezone)
+ return s
+
def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
@@ -271,16 +282,11 @@ def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
- from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
- _check_dataframe_localize_timestamps
import pyarrow as pa
reader = pa.open_stream(stream)
- schema = from_arrow_schema(reader.schema)
+
for batch in reader:
- pdf = batch.to_pandas()
- pdf = _check_dataframe_convert_date(pdf, schema)
- pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
- yield [c for _, c in pdf.iteritems()]
+ yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
def __repr__(self):
return "ArrowStreamPandasSerializer"
@@ -456,7 +462,7 @@ def dumps(self, obj):
return obj
-# Hook namedtuple, make it picklable
+# Hack namedtuple, make it picklable
__cls = {}
@@ -520,15 +526,15 @@ def namedtuple(*args, **kwargs):
cls = _old_namedtuple(*args, **kwargs)
return _hack_namedtuple(cls)
- # replace namedtuple with new one
+ # replace namedtuple with the new one
collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults
collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple
collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple
collections.namedtuple.__code__ = namedtuple.__code__
collections.namedtuple.__hijack = 1
- # hack the cls already generated by namedtuple
- # those created in other module can be pickled as normal,
+ # hack the cls already generated by namedtuple.
+ # Those created in other modules can be pickled as normal,
# so only hack those in __main__ module
for n, o in sys.modules["__main__"].__dict__.items():
if (type(o) is type and o.__base__ is tuple
@@ -565,7 +571,18 @@ def loads(self, obj, encoding=None):
class CloudPickleSerializer(PickleSerializer):
def dumps(self, obj):
- return cloudpickle.dumps(obj, 2)
+ try:
+ return cloudpickle.dumps(obj, 2)
+ except pickle.PickleError:
+ raise
+ except Exception as e:
+ emsg = _exception_message(e)
+ if "'i' format requires" in emsg:
+ msg = "Object too large to serialize: %s" % emsg
+ else:
+ msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
+ cloudpickle.print_exec(sys.stderr)
+ raise pickle.PicklingError(msg)
class MarshalSerializer(FramedSerializer):
@@ -611,7 +628,7 @@ def loads(self, obj):
elif _type == b'P':
return pickle.loads(obj[1:])
else:
- raise ValueError("invalid sevialization type: %s" % _type)
+ raise ValueError("invalid serialization type: %s" % _type)
class CompressedSerializer(FramedSerializer):
@@ -699,4 +716,4 @@ def write_with_length(obj, stream):
import doctest
(failure_count, test_count) = doctest.testmod()
if failure_count:
- exit(-1)
+ sys.exit(-1)
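The effect of the new CloudPickleSerializer error handling, roughly; a thread lock is just a convenient example of an unpicklable object:

    import threading
    from pyspark.serializers import CloudPickleSerializer

    ser = CloudPickleSerializer()
    try:
        ser.dumps(threading.Lock())
    except Exception as e:
        # Expected to surface as pickle.PicklingError("Could not serialize object: ...")
        print(type(e).__name__, e)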
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index b5fcf7092d93a..472c3cd4452f0 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -38,25 +38,13 @@
SparkContext._ensure_initialized()
try:
- # Try to access HiveConf, it will raise exception if Hive is not added
- conf = SparkConf()
- if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive':
- SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf()
- spark = SparkSession.builder\
- .enableHiveSupport()\
- .getOrCreate()
- else:
- spark = SparkSession.builder.getOrCreate()
-except py4j.protocol.Py4JError:
- if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive':
- warnings.warn("Fall back to non-hive support because failing to access HiveConf, "
- "please make sure you build spark with hive")
- spark = SparkSession.builder.getOrCreate()
-except TypeError:
- if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive':
- warnings.warn("Fall back to non-hive support because failing to access HiveConf, "
- "please make sure you build spark with hive")
- spark = SparkSession.builder.getOrCreate()
+ spark = SparkSession._create_shell_session()
+except Exception:
+ import sys
+ import traceback
+ warnings.warn("Failed to initialize Spark session.")
+ traceback.print_exc(file=sys.stderr)
+ sys.exit(1)
sc = spark.sparkContext
sql = spark.sql
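A rough reconstruction of what SparkSession._create_shell_session() now encapsulates, based on the logic removed above; this is an assumption about the helper, not its actual source:

    from pyspark import SparkConf, SparkContext
    from pyspark.sql import SparkSession

    def _create_shell_session_sketch():
        conf = SparkConf()
        try:
            # Probe HiveConf; fall back to the in-memory catalog if Hive is unavailable.
            if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive':
                SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf()
                return SparkSession.builder.enableHiveSupport().getOrCreate()
        except Exception:
            pass
        return SparkSession.builder.getOrCreate()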
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index e974cda9fc3e1..bd0ac0039ffe1 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -23,10 +23,12 @@
import itertools
import operator
import random
+import sys
import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
CompressedSerializer, AutoBatchedSerializer
+from pyspark.util import fail_on_stopiteration
try:
@@ -93,9 +95,9 @@ class Aggregator(object):
"""
def __init__(self, createCombiner, mergeValue, mergeCombiners):
- self.createCombiner = createCombiner
- self.mergeValue = mergeValue
- self.mergeCombiners = mergeCombiners
+ self.createCombiner = fail_on_stopiteration(createCombiner)
+ self.mergeValue = fail_on_stopiteration(mergeValue)
+ self.mergeCombiners = fail_on_stopiteration(mergeCombiners)
class SimpleAggregator(Aggregator):
@@ -810,4 +812,4 @@ def load_partition(j):
import doctest
(failure_count, test_count) = doctest.testmod()
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 6aef0f22340be..b0d8357f4feec 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import sys
import warnings
from collections import namedtuple
@@ -306,7 +307,7 @@ def _test():
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
_test()
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 43b38a2cd477c..e7dec11c69b57 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -16,7 +16,6 @@
#
import sys
-import warnings
import json
if sys.version >= '3':
@@ -448,24 +447,72 @@ def isin(self, *cols):
# order
_asc_doc = """
- Returns a sort expression based on the ascending order of the given column name
+ Returns a sort expression based on ascending order of the column.
>>> from pyspark.sql import Row
- >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)])
+ >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.asc()).collect()
[Row(name=u'Alice'), Row(name=u'Tom')]
"""
+ _asc_nulls_first_doc = """
+ Returns a sort expression based on ascending order of the column, and null values
+ appear before non-null values.
+
+ >>> from pyspark.sql import Row
+ >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
+ >>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect()
+ [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')]
+
+ .. versionadded:: 2.4
+ """
+ _asc_nulls_last_doc = """
+ Returns a sort expression based on ascending order of the column, and null values
+ appear after non-null values.
+
+ >>> from pyspark.sql import Row
+ >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
+ >>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect()
+ [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)]
+
+ .. versionadded:: 2.4
+ """
_desc_doc = """
- Returns a sort expression based on the descending order of the given column name.
+ Returns a sort expression based on the descending order of the column.
>>> from pyspark.sql import Row
- >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)])
+ >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"])
>>> df.select(df.name).orderBy(df.name.desc()).collect()
[Row(name=u'Tom'), Row(name=u'Alice')]
"""
+ _desc_nulls_first_doc = """
+ Returns a sort expression based on the descending order of the column, and null values
+ appear before non-null values.
+
+ >>> from pyspark.sql import Row
+ >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
+ >>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect()
+ [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')]
+
+ .. versionadded:: 2.4
+ """
+ _desc_nulls_last_doc = """
+ Returns a sort expression based on the descending order of the column, and null values
+ appear after non-null values.
+
+ >>> from pyspark.sql import Row
+ >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"])
+ >>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect()
+ [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]
+
+ .. versionadded:: 2.4
+ """
asc = ignore_unicode_prefix(_unary_op("asc", _asc_doc))
+ asc_nulls_first = ignore_unicode_prefix(_unary_op("asc_nulls_first", _asc_nulls_first_doc))
+ asc_nulls_last = ignore_unicode_prefix(_unary_op("asc_nulls_last", _asc_nulls_last_doc))
desc = ignore_unicode_prefix(_unary_op("desc", _desc_doc))
+ desc_nulls_first = ignore_unicode_prefix(_unary_op("desc_nulls_first", _desc_nulls_first_doc))
+ desc_nulls_last = ignore_unicode_prefix(_unary_op("desc_nulls_last", _desc_nulls_last_doc))
_isNull_doc = """
True if the current expression is null.
@@ -660,7 +707,7 @@ def _test():
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
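Quick usage sketch for the new null-ordering sort expressions; the data mirrors the doctests above and assumes an active `spark` session:

    df = spark.createDataFrame([("Tom", 80), (None, 60), ("Alice", None)], ["name", "height"])

    df.orderBy(df.name.asc_nulls_first()).show()   # the None name sorts first
    df.orderBy(df.name.desc_nulls_last()).show()   # the None name sorts last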
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 792c420ca6386..db49040e17b63 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -15,7 +15,9 @@
# limitations under the License.
#
-from pyspark import since
+import sys
+
+from pyspark import since, _NoValue
from pyspark.rdd import ignore_unicode_prefix
@@ -37,15 +39,16 @@ def set(self, key, value):
@ignore_unicode_prefix
@since(2.0)
- def get(self, key, default=None):
+ def get(self, key, default=_NoValue):
"""Returns the value of Spark runtime configuration property for the given key,
assuming it is set.
"""
self._checkType(key, "key")
- if default is None:
+ if default is _NoValue:
return self._jconf.get(key)
else:
- self._checkType(default, "default")
+ if default is not None:
+ self._checkType(default, "default")
return self._jconf.get(key, default)
@ignore_unicode_prefix
@@ -64,7 +67,6 @@ def _checkType(self, obj, identifier):
def _test():
import os
import doctest
- from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
import pyspark.sql.conf
@@ -80,7 +82,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(pyspark.sql.conf, globs=globs)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
_test()
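Behaviour sketch for RuntimeConfig.get after switching the default sentinel to _NoValue; the key name is illustrative:

    spark.conf.get("spark.some.unset.key")         # still raises when the key is unset
    spark.conf.get("spark.some.unset.key", None)   # now allowed: returns None
    spark.conf.get("spark.some.unset.key", "10")   # returns "10"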
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index cc1cd1a5842d9..e9ec7ba866761 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -22,7 +22,7 @@
if sys.version >= '3':
basestring = unicode = str
-from pyspark import since
+from pyspark import since, _NoValue
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.session import _monkey_patch_RDD, SparkSession
from pyspark.sql.dataframe import DataFrame
@@ -124,11 +124,11 @@ def setConf(self, key, value):
@ignore_unicode_prefix
@since(1.3)
- def getConf(self, key, defaultValue=None):
+ def getConf(self, key, defaultValue=_NoValue):
"""Returns the value of Spark SQL configuration property for the given key.
- If the key is not set and defaultValue is not None, return
- defaultValue. If the key is not set and defaultValue is None, return
+ If the key is not set and defaultValue is set, return
+ defaultValue. If the key is not set and defaultValue is not set, return
the system default value.
>>> sqlContext.getConf("spark.sql.shuffle.partitions")
@@ -543,7 +543,7 @@ def _test():
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
globs['sc'].stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 59a417015b949..1e6a1acebb5ca 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -27,7 +27,7 @@
import warnings
-from pyspark import copy_func, since
+from pyspark import copy_func, since, _NoValue
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
UTF8Deserializer
@@ -78,6 +78,9 @@ def __init__(self, jdf, sql_ctx):
self.is_cached = False
self._schema = None # initialized lazily
self._lazy_rdd = None
+ # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
+ # by __repr__ and _repr_html_ while eager evaluation opened.
+ self._support_repr_html = False
@property
@since(1.3)
@@ -351,8 +354,68 @@ def show(self, n=20, truncate=True, vertical=False):
else:
print(self._jdf.showString(n, int(truncate), vertical))
+ @property
+ def _eager_eval(self):
+ """Returns True if eager evaluation is enabled.
+ """
+ return self.sql_ctx.getConf(
+ "spark.sql.repl.eagerEval.enabled", "false").lower() == "true"
+
+ @property
+ def _max_num_rows(self):
+ """Returns the maximum number of rows for eager evaluation.
+ """
+ return int(self.sql_ctx.getConf(
+ "spark.sql.repl.eagerEval.maxNumRows", "20"))
+
+ @property
+ def _truncate(self):
+ """Returns the truncation length for eager evaluation.
+ """
+ return int(self.sql_ctx.getConf(
+ "spark.sql.repl.eagerEval.truncate", "20"))
+
def __repr__(self):
- return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+ if not self._support_repr_html and self._eager_eval:
+ vertical = False
+ return self._jdf.showString(
+ self._max_num_rows, self._truncate, vertical)
+ else:
+ return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+
+ def _repr_html_(self):
+ """Returns the DataFrame rendered as an HTML table when eager evaluation is
+ enabled via 'spark.sql.repl.eagerEval.enabled'. This is only called by REPLs
+ that support rendering HTML output.
+ """
+ import cgi
+ if not self._support_repr_html:
+ self._support_repr_html = True
+ if self._eager_eval:
+ max_num_rows = max(self._max_num_rows, 0)
+ vertical = False
+ sock_info = self._jdf.getRowsToPython(
+ max_num_rows, self._truncate, vertical)
+ rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
+ head = rows[0]
+ row_data = rows[1:]
+ has_more_data = len(row_data) > max_num_rows
+ row_data = row_data[:max_num_rows]
+
+ html = "<table border='1'>\n"
+ # generate table head
+ html += "<tr><th>%s</th></tr>\n" % "</th><th>".join(map(lambda x: cgi.escape(x), head))
+ # generate table rows
+ for row in row_data:
+ html += "<tr><td>%s</td></tr>\n" % "</td><td>".join(
+ map(lambda x: cgi.escape(x), row))
+ html += "</table>\n"
+ if has_more_data:
+ html += "only showing top %d %s\n" % (
+ max_num_rows, "row" if max_num_rows == 1 else "rows")
+ return html
+ else:
+ return None
@since(2.1)
def checkpoint(self, eager=True):
@@ -463,8 +526,8 @@ def collect(self):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- port = self._jdf.collectToPython()
- return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+ sock_info = self._jdf.collectToPython()
+ return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
@ignore_unicode_prefix
@since(2.0)
@@ -477,8 +540,8 @@ def toLocalIterator(self):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- port = self._jdf.toPythonIterator()
- return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+ sock_info = self._jdf.toPythonIterator()
+ return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
@ignore_unicode_prefix
@since(1.3)
@@ -588,6 +651,8 @@ def coalesce(self, numPartitions):
"""
Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions.
+ :param numPartitions: int, to specify the target number of partitions
+
Similar to coalesce defined on an :class:`RDD`, this operation results in a
narrow dependency, e.g. if you go from 1000 partitions to 100 partitions,
there will not be a shuffle, instead each of the 100 new partitions will
@@ -612,9 +677,10 @@ def repartition(self, numPartitions, *cols):
Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
resulting DataFrame is hash partitioned.
- ``numPartitions`` can be an int to specify the target number of partitions or a Column.
- If it is a Column, it will be used as the first partitioning column. If not specified,
- the default number of partitions is used.
+ :param numPartitions:
+ can be an int to specify the target number of partitions or a Column.
+ If it is a Column, it will be used as the first partitioning column. If not specified,
+ the default number of partitions is used.
.. versionchanged:: 1.6
Added optional arguments to specify the partitioning columns. Also made numPartitions
@@ -667,6 +733,52 @@ def repartition(self, numPartitions, *cols):
else:
raise TypeError("numPartitions should be an int or Column")
+ @since("2.4.0")
+ def repartitionByRange(self, numPartitions, *cols):
+ """
+ Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
+ resulting DataFrame is range partitioned.
+
+ :param numPartitions:
+ can be an int to specify the target number of partitions or a Column.
+ If it is a Column, it will be used as the first partitioning column. If not specified,
+ the default number of partitions is used.
+
+ At least one partition-by expression must be specified.
+ When no explicit sort order is specified, "ascending nulls first" is assumed.
+
+ >>> df.repartitionByRange(2, "age").rdd.getNumPartitions()
+ 2
+ >>> df.show()
+ +---+-----+
+ |age| name|
+ +---+-----+
+ | 2|Alice|
+ | 5| Bob|
+ +---+-----+
+ >>> df.repartitionByRange(1, "age").rdd.getNumPartitions()
+ 1
+ >>> data = df.repartitionByRange("age")
+ >>> df.show()
+ +---+-----+
+ |age| name|
+ +---+-----+
+ | 2|Alice|
+ | 5| Bob|
+ +---+-----+
+ """
+ if isinstance(numPartitions, int):
+ if len(cols) == 0:
+ raise ValueError("At least one partition-by expression must be specified.")
+ else:
+ return DataFrame(
+ self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx)
+ elif isinstance(numPartitions, (basestring, Column)):
+ cols = (numPartitions,) + cols
+ return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx)
+ else:
+ raise TypeError("numPartitions should be an int, string or Column")
+
@since(1.3)
def distinct(self):
"""Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.
@@ -847,6 +959,8 @@ def colRegex(self, colName):
def alias(self, alias):
"""Returns a new :class:`DataFrame` with an alias set.
+ :param alias: string, an alias name to be set for the DataFrame.
+
>>> from pyspark.sql.functions import *
>>> df_as1 = df.alias("df_as1")
>>> df_as2 = df.alias("df_as2")
@@ -1532,7 +1646,7 @@ def fillna(self, value, subset=None):
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
@since(1.4)
- def replace(self, to_replace, value=None, subset=None):
+ def replace(self, to_replace, value=_NoValue, subset=None):
"""Returns a new :class:`DataFrame` replacing a value with another value.
:func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
aliases of each other.
@@ -1545,8 +1659,8 @@ def replace(self, to_replace, value=None, subset=None):
:param to_replace: bool, int, long, float, string, list or dict.
Value to be replaced.
- If the value is a dict, then `value` is ignored and `to_replace` must be a
- mapping between a value and a replacement.
+ If the value is a dict, then `value` is ignored or can be omitted, and `to_replace`
+ must be a mapping between a value and a replacement.
:param value: bool, int, long, float, string, list or None.
The replacement value must be a bool, int, long, float, string or None. If `value` is a
list, `value` should be of the same length and type as `to_replace`.
@@ -1577,6 +1691,16 @@ def replace(self, to_replace, value=None, subset=None):
|null| null|null|
+----+------+----+
+ >>> df4.na.replace({'Alice': None}).show()
+ +----+------+----+
+ | age|height|name|
+ +----+------+----+
+ | 10| 80|null|
+ | 5| null| Bob|
+ |null| null| Tom|
+ |null| null|null|
+ +----+------+----+
+
>>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+----+------+----+
| age|height|name|
@@ -1587,6 +1711,12 @@ def replace(self, to_replace, value=None, subset=None):
|null| null|null|
+----+------+----+
"""
+ if value is _NoValue:
+ if isinstance(to_replace, dict):
+ value = None
+ else:
+ raise TypeError("value argument is required when to_replace is not a dictionary.")
+
# Helper functions
def all_of(types):
"""Given a type or tuple of types and a sequence of xs
@@ -1839,7 +1969,7 @@ def withColumnRenamed(self, existing, new):
This is a no-op if schema doesn't contain the given column name.
:param existing: string, name of the existing column to rename.
- :param col: string, new name of the column.
+ :param new: string, new name of the column.
>>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
@@ -1908,11 +2038,16 @@ def toPandas(self):
.. note:: This method should only be used if the resulting Pandas's DataFrame is expected
to be small, as all the data is loaded into the driver's memory.
+ .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental.
+
>>> df.toPandas() # doctest: +SKIP
age name
0 2 Alice
1 5 Bob
"""
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+
import pandas as pd
if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
@@ -1922,52 +2057,92 @@ def toPandas(self):
timezone = None
if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
+ use_arrow = True
try:
- from pyspark.sql.types import _check_dataframe_convert_date, \
- _check_dataframe_localize_timestamps
+ from pyspark.sql.types import to_arrow_schema
from pyspark.sql.utils import require_minimum_pyarrow_version
- import pyarrow
+
require_minimum_pyarrow_version()
- tables = self._collectAsArrow()
- if tables:
- table = pyarrow.concat_tables(tables)
- pdf = table.to_pandas()
- pdf = _check_dataframe_convert_date(pdf, self.schema)
- return _check_dataframe_localize_timestamps(pdf, timezone)
+ to_arrow_schema(self.schema)
+ except Exception as e:
+
+ if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \
+ .lower() == "true":
+ msg = (
+ "toPandas attempted Arrow optimization because "
+ "'spark.sql.execution.arrow.enabled' is set to true; however, "
+ "failed by the reason below:\n %s\n"
+ "Attempting non-optimization as "
+ "'spark.sql.execution.arrow.fallback.enabled' is set to "
+ "true." % _exception_message(e))
+ warnings.warn(msg)
+ use_arrow = False
else:
- return pd.DataFrame.from_records([], columns=self.columns)
- except ImportError as e:
- msg = "note: pyarrow must be installed and available on calling Python process " \
- "if using spark.sql.execution.arrow.enabled=true"
- raise ImportError("%s\n%s" % (_exception_message(e), msg))
+ msg = (
+ "toPandas attempted Arrow optimization because "
+ "'spark.sql.execution.arrow.enabled' is set to true, but has reached "
+ "the error below and will not continue because automatic fallback "
+ "with 'spark.sql.execution.arrow.fallback.enabled' has been set to "
+ "false.\n %s" % _exception_message(e))
+ warnings.warn(msg)
+ raise
+
+ # Try to use Arrow optimization when the schema is supported and the required version
+ # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled.
+ if use_arrow:
+ try:
+ from pyspark.sql.types import _check_dataframe_convert_date, \
+ _check_dataframe_localize_timestamps
+ import pyarrow
+
+ tables = self._collectAsArrow()
+ if tables:
+ table = pyarrow.concat_tables(tables)
+ pdf = table.to_pandas()
+ pdf = _check_dataframe_convert_date(pdf, self.schema)
+ return _check_dataframe_localize_timestamps(pdf, timezone)
+ else:
+ return pd.DataFrame.from_records([], columns=self.columns)
+ except Exception as e:
+ # We might have to allow fallback here as well but multiple Spark jobs can
+ # be executed. So, simply fail in this case for now.
+ msg = (
+ "toPandas attempted Arrow optimization because "
+ "'spark.sql.execution.arrow.enabled' is set to true, but has reached "
+ "the error below and can not continue. Note that "
+ "'spark.sql.execution.arrow.fallback.enabled' does not have an effect "
+ "on failures in the middle of computation.\n %s" % _exception_message(e))
+ warnings.warn(msg)
+ raise
+
+ # Below is toPandas without Arrow optimization.
+ pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+ dtype = {}
+ for field in self.schema:
+ pandas_type = _to_corrected_pandas_type(field.dataType)
+ # SPARK-21766: if an integer field is nullable and has null values, it can be
+ # inferred by pandas as a float column. Once we convert the column with NaN back
+ # to an integer type, e.g., np.int16, we will hit an exception. So we use the inferred
+ # float type, not the corrected type from the schema in this case.
+ if pandas_type is not None and \
+ not(isinstance(field.dataType, IntegralType) and field.nullable and
+ pdf[field.name].isnull().any()):
+ dtype[field.name] = pandas_type
+
+ for f, t in dtype.items():
+ pdf[f] = pdf[f].astype(t, copy=False)
+
+ if timezone is None:
+ return pdf
else:
- pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
-
- dtype = {}
+ from pyspark.sql.types import _check_series_convert_timestamps_local_tz
for field in self.schema:
- pandas_type = _to_corrected_pandas_type(field.dataType)
- # SPARK-21766: if an integer field is nullable and has null values, it can be
- # inferred by pandas as float column. Once we convert the column with NaN back
- # to integer type e.g., np.int16, we will hit exception. So we use the inferred
- # float type, not the corrected type from the schema in this case.
- if pandas_type is not None and \
- not(isinstance(field.dataType, IntegralType) and field.nullable and
- pdf[field.name].isnull().any()):
- dtype[field.name] = pandas_type
-
- for f, t in dtype.items():
- pdf[f] = pdf[f].astype(t, copy=False)
-
- if timezone is None:
- return pdf
- else:
- from pyspark.sql.types import _check_series_convert_timestamps_local_tz
- for field in self.schema:
- # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
- if isinstance(field.dataType, TimestampType):
- pdf[field.name] = \
- _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
- return pdf
+ # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+ if isinstance(field.dataType, TimestampType):
+ pdf[field.name] = \
+ _check_series_convert_timestamps_local_tz(pdf[field.name], timezone)
+ return pdf
def _collectAsArrow(self):
"""
@@ -1977,8 +2152,8 @@ def _collectAsArrow(self):
.. note:: Experimental.
"""
with SCCallSiteSync(self._sc) as css:
- port = self._jdf.collectAsArrowToPython()
- return list(_load_from_socket(port, ArrowSerializer()))
+ sock_info = self._jdf.collectAsArrowToPython()
+ return list(_load_from_socket(sock_info, ArrowSerializer()))
##########################################################################################
# Pandas compatibility
@@ -2044,7 +2219,7 @@ def fill(self, value, subset=None):
fill.__doc__ = DataFrame.fillna.__doc__
- def replace(self, to_replace, value, subset=None):
+ def replace(self, to_replace, value=_NoValue, subset=None):
return self.df.replace(to_replace, value, subset)
replace.__doc__ = DataFrame.replace.__doc__
@@ -2122,7 +2297,7 @@ def _test():
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
globs['sc'].stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
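
Note (illustrative only, not part of the patch): a minimal sketch of the dataframe.py changes above, assuming a local SparkSession. The dict form of `na.replace` may now omit `value`, while the non-dict form without a value raises a descriptive TypeError.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(10, 80, "Alice"), (5, 60, "Bob")], ["age", "height", "name"])

    # Dict form: each key is replaced by its mapped value; `value` defaults to None.
    df.na.replace({"Alice": None}).show()

    # Non-dict form without an explicit value now fails fast with a clear message.
    try:
        df.na.replace("Alice")
    except TypeError as e:
        print(e)  # value argument is required when to_replace is not a dictionary.

    spark.stop()
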
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 05031f5ec87d7..e6346691fb1d4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -18,7 +18,6 @@
"""
A collections of builtin functions
"""
-import math
import sys
import functools
import warnings
@@ -28,10 +27,10 @@
from pyspark import since, SparkContext
from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StringType, DataType
+# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
from pyspark.sql.udf import UserDefinedFunction, _create_udf
@@ -106,18 +105,15 @@ def _():
_functions_1_4 = {
# unary math functions
- 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' +
- '0.0 through pi.',
- 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' +
- '-pi/2 through pi/2.',
- 'atan': 'Computes the tangent inverse of the given value; the returned angle is in the range' +
- '-pi/2 through pi/2',
+ 'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`',
+ 'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`',
+ 'atan': ':return: inverse tangent of `col`, as if computed by `java.lang.Math.atan()`',
'cbrt': 'Computes the cube-root of the given value.',
'ceil': 'Computes the ceiling of the given value.',
- 'cos': """Computes the cosine of the given value.
-
- :param col: :class:`DoubleType` column, units in radians.""",
- 'cosh': 'Computes the hyperbolic cosine of the given value.',
+ 'cos': """:param col: angle in radians
+ :return: cosine of the angle, as if computed by `java.lang.Math.cos()`.""",
+ 'cosh': """:param col: hyperbolic angle
+ :return: hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh()`""",
'exp': 'Computes the exponential of the given value.',
'expm1': 'Computes the exponential of the given value minus one.',
'floor': 'Computes the floor of the given value.',
@@ -127,22 +123,38 @@ def _():
'rint': 'Returns the double value that is closest in value to the argument and' +
' is equal to a mathematical integer.',
'signum': 'Computes the signum of the given value.',
- 'sin': """Computes the sine of the given value.
-
- :param col: :class:`DoubleType` column, units in radians.""",
- 'sinh': 'Computes the hyperbolic sine of the given value.',
- 'tan': """Computes the tangent of the given value.
-
- :param col: :class:`DoubleType` column, units in radians.""",
- 'tanh': 'Computes the hyperbolic tangent of the given value.',
+ 'sin': """:param col: angle in radians
+ :return: sine of the angle, as if computed by `java.lang.Math.sin()`""",
+ 'sinh': """:param col: hyperbolic angle
+ :return: hyperbolic sine of the given value,
+ as if computed by `java.lang.Math.sinh()`""",
+ 'tan': """:param col: angle in radians
+ :return: tangent of the given value, as if computed by `java.lang.Math.tan()`""",
+ 'tanh': """:param col: hyperbolic angle
+ :return: hyperbolic tangent of the given value,
+ as if computed by `java.lang.Math.tanh()`""",
'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.',
'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.',
'bitwiseNOT': 'Computes bitwise not.',
}
+_functions_2_4 = {
+ 'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' +
+ ' column name, and null values appear before non-null values.',
+ 'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' +
+ ' column name, and null values appear after non-null values.',
+ 'desc_nulls_first': 'Returns a sort expression based on the descending order of the given' +
+ ' column name, and null values appear before non-null values.',
+ 'desc_nulls_last': 'Returns a sort expression based on the descending order of the given' +
+ ' column name, and null values appear after non-null values.',
+}
+
_collect_list_doc = """
Aggregate function: returns a list of objects with duplicates.
+ .. note:: The function is non-deterministic because the order of collected results depends
+ on the order of rows, which may be non-deterministic after a shuffle.
+
>>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
>>> df2.agg(collect_list('age')).collect()
[Row(collect_list(age)=[2, 5, 5])]
@@ -150,6 +162,9 @@ def _():
_collect_set_doc = """
Aggregate function: returns a set of objects with duplicate elements eliminated.
+ .. note:: The function is non-deterministic because the order of collected results depends
+ on the order of rows, which may be non-deterministic after a shuffle.
+
>>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
>>> df2.agg(collect_set('age')).collect()
[Row(collect_set(age)=[5, 2])]
@@ -173,16 +188,31 @@ def _():
_functions_2_1 = {
# unary math functions
- 'degrees': 'Converts an angle measured in radians to an approximately equivalent angle ' +
- 'measured in degrees.',
- 'radians': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
- 'measured in radians.',
+ 'degrees': """
+ Converts an angle measured in radians to an approximately equivalent angle
+ measured in degrees.
+ :param col: angle in radians
+ :return: angle in degrees, as if computed by `java.lang.Math.toDegrees()`
+ """,
+ 'radians': """
+ Converts an angle measured in degrees to an approximately equivalent angle
+ measured in radians.
+ :param col: angle in degrees
+ :return: angle in radians, as if computed by `java.lang.Math.toRadians()`
+ """,
}
# math functions that take two arguments as input
_binary_mathfunctions = {
- 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
- 'polar coordinates (r, theta). Units in radians.',
+ 'atan2': """
+ :param col1: coordinate on y-axis
+ :param col2: coordinate on x-axis
+ :return: the `theta` component of the point
+ (`r`, `theta`)
+ in polar coordinates that corresponds to the point
+ (`x`, `y`) in Cartesian coordinates,
+ as if computed by `java.lang.Math.atan2()`
+ """,
'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.',
'pow': 'Returns the value of the first argument raised to the power of the second argument.',
}
@@ -237,6 +267,8 @@ def _():
globals()[_name] = since(2.1)(_create_function(_name, _doc))
for _name, _message in _functions_deprecated.items():
globals()[_name] = _wrap_deprecated_function(globals()[_name], _message)
+for _name, _doc in _functions_2_4.items():
+ globals()[_name] = since(2.4)(_create_function(_name, _doc))
del _name, _doc
@@ -375,6 +407,9 @@ def first(col, ignorenulls=False):
The function by default returns the first values it sees. It will return the first non-null
value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+
+ .. note:: The function is non-deterministic because its result depends on the order of rows, which
+ may be non-deterministic after a shuffle.
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls)
@@ -463,6 +498,9 @@ def last(col, ignorenulls=False):
The function by default returns the last values it sees. It will return the last non-null
value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+
+ .. note:: The function is non-deterministic because its result depends on the order of rows,
+ which may be non-deterministic after a shuffle.
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls)
@@ -478,6 +516,8 @@ def monotonically_increasing_id():
within each partition in the lower 33 bits. The assumption is that the data frame has
less than 1 billion partitions, and each partition has less than 8 billion records.
+ .. note:: The function is non-deterministic because its result depends on partition IDs.
+
As an example, consider a :class:`DataFrame` with two partitions, each with 3 records.
This expression would return the following IDs:
0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.
@@ -510,6 +550,8 @@ def rand(seed=None):
"""Generates a random column with independent and identically distributed (i.i.d.) samples
from U[0.0, 1.0].
+ .. note:: The function is non-deterministic in the general case.
+
>>> df.withColumn('rand', rand(seed=42) * 3).collect()
[Row(age=2, name=u'Alice', rand=1.1568609015300986),
Row(age=5, name=u'Bob', rand=1.403379671529166)]
@@ -528,6 +570,8 @@ def randn(seed=None):
"""Generates a column with independent and identically distributed (i.i.d.) samples from
the standard normal distribution.
+ .. note:: The function is non-deterministic in the general case.
+
>>> df.withColumn('randn', randn(seed=42)).collect()
[Row(age=2, name=u'Alice', randn=-0.7556247885860078),
Row(age=5, name=u'Bob', randn=-0.0861619008451133)]
@@ -809,6 +853,36 @@ def ntile(n):
return Column(sc._jvm.functions.ntile(int(n)))
+@since(2.4)
+def unboundedPreceding():
+ """
+ Window function: returns the special frame boundary that represents the first row
+ in the window partition.
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.unboundedPreceding())
+
+
+@since(2.4)
+def unboundedFollowing():
+ """
+ Window function: returns the special frame boundary that represents the last row
+ in the window partition.
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.unboundedFollowing())
+
+
+@since(2.4)
+def currentRow():
+ """
+ Window function: returns the special frame boundary that represents the current row
+ in the window partition.
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.currentRow())
+
+
# ---------------------- Date/Timestamp functions ------------------------------
@since(1.5)
@@ -1032,16 +1106,23 @@ def add_months(start, months):
@since(1.5)
-def months_between(date1, date2):
+def months_between(date1, date2, roundOff=True):
"""
- Returns the number of months between date1 and date2.
+ Returns the number of months between dates `date1` and `date2`.
+ If `date1` is later than `date2`, then the result is positive.
+ If `date1` and `date2` are on the same day of the month, or both are the last day of the month,
+ returns an integer (time of day will be ignored).
+ The result is rounded off to 8 digits unless `roundOff` is set to `False`.
>>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2'])
>>> df.select(months_between(df.date1, df.date2).alias('months')).collect()
- [Row(months=3.9495967...)]
+ [Row(months=3.94959677)]
+ >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect()
+ [Row(months=3.9495967741935485)]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))
+ return Column(sc._jvm.functions.months_between(
+ _to_java_column(date1), _to_java_column(date2), roundOff))
@since(2.2)
@@ -1358,7 +1439,6 @@ def hash(*cols):
'uppercase. Words are delimited by whitespace.',
'lower': 'Converts a string column to lower case.',
'upper': 'Converts a string column to upper case.',
- 'reverse': 'Reverses the string column and returns it as a new string column.',
'ltrim': 'Trim the spaces from left end for the specified string value.',
'rtrim': 'Trim the spaces from right end for the specified string value.',
'trim': 'Trim the spaces from both ends for the specified string column.',
@@ -1370,21 +1450,6 @@ def hash(*cols):
del _name, _doc
-@since(1.5)
-@ignore_unicode_prefix
-def concat(*cols):
- """
- Concatenates multiple input columns together into a single column.
- If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
-
- >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
- >>> df.select(concat(df.s, df.d).alias('s')).collect()
- [Row(s=u'abcd123')]
- """
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
-
-
@since(1.5)
@ignore_unicode_prefix
def concat_ws(sep, *cols):
@@ -1754,6 +1819,25 @@ def create_map(*cols):
return Column(jc)
+@since(2.4)
+def map_from_arrays(col1, col2):
+ """Creates a new map from two arrays.
+
+ :param col1: name of column containing a set of keys. All elements should not be null
+ :param col2: name of column containing a set of values
+
+ >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v'])
+ >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show()
+ +----------------+
+ | map|
+ +----------------+
+ |[2 -> a, 5 -> b]|
+ +----------------+
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2)))
+
+
@since(1.4)
def array(*cols):
"""Creates a new array column.
@@ -1790,6 +1874,131 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
+@since(2.4)
+def arrays_overlap(a1, a2):
+ """
+ Collection function: returns true if the arrays contain any common non-null element; if not,
+ returns null if both the arrays are non-empty and any of them contains a null element; returns
+ false otherwise.
+
+ >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y'])
+ >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
+ [Row(overlap=True), Row(overlap=False)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2)))
+
+
+@since(2.4)
+def slice(x, start, length):
+ """
+ Collection function: returns an array containing all the elements in `x` from index `start`
+ (or starting from the end if `start` is negative) with the specified `length`.
+
+ >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
+ >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
+ [Row(sliced=[2, 3]), Row(sliced=[5])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.slice(_to_java_column(x), start, length))
+
+
+@ignore_unicode_prefix
+@since(2.4)
+def array_join(col, delimiter, null_replacement=None):
+ """
+ Concatenates the elements of `col` using the `delimiter`. Null values are replaced with
+ `null_replacement` if set, otherwise they are ignored.
+
+ >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])
+ >>> df.select(array_join(df.data, ",").alias("joined")).collect()
+ [Row(joined=u'a,b,c'), Row(joined=u'a')]
+ >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()
+ [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')]
+ """
+ sc = SparkContext._active_spark_context
+ if null_replacement is None:
+ return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter))
+ else:
+ return Column(sc._jvm.functions.array_join(
+ _to_java_column(col), delimiter, null_replacement))
+
+
+@since(1.5)
+@ignore_unicode_prefix
+def concat(*cols):
+ """
+ Concatenates multiple input columns together into a single column.
+ The function works with strings, binary and compatible array columns.
+
+ >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
+ >>> df.select(concat(df.s, df.d).alias('s')).collect()
+ [Row(s=u'abcd123')]
+
+ >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
+ >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
+ [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
+
+
+@since(2.4)
+def array_position(col, value):
+ """
+ Collection function: Locates the position of the first occurrence of the given value
+ in the given array. Returns null if either of the arguments is null.
+
+ .. note:: The position is not zero-based, but a 1-based index. Returns 0 if the given
+ value could not be found in the array.
+
+ >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
+ >>> df.select(array_position(df.data, "a")).collect()
+ [Row(array_position(data, a)=3), Row(array_position(data, a)=0)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_position(_to_java_column(col), value))
+
+
+@ignore_unicode_prefix
+@since(2.4)
+def element_at(col, extraction):
+ """
+ Collection function: Returns the element of the array at the given index in `extraction` if
+ `col` is an array. Returns the value for the given key in `extraction` if `col` is a map.
+
+ :param col: name of column containing array or map
+ :param extraction: index to check for in array or key to check for in map
+
+ .. note:: The position is not zero-based, but a 1-based index.
+
+ >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
+ >>> df.select(element_at(df.data, 1)).collect()
+ [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)]
+
+ >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data'])
+ >>> df.select(element_at(df.data, "a")).collect()
+ [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction))
+
+
+@since(2.4)
+def array_remove(col, element):
+ """
+ Collection function: Removes all elements equal to `element` from the given array.
+
+ :param col: name of column containing array
+ :param element: element to be removed from the array
+
+ >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data'])
+ >>> df.select(array_remove(df.data, 1)).collect()
+ [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_remove(_to_java_column(col), element))
+
+
@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
@@ -1936,12 +2145,13 @@ def json_tuple(col, *fields):
return Column(jc)
+@ignore_unicode_prefix
@since(2.1)
def from_json(col, schema, options={}):
"""
- Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType`
- of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an
- unparseable string.
+ Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType`
+ as keys type, :class:`StructType` or :class:`ArrayType` of :class:`StructType`\\s with
+ the specified schema. Returns `null`, in the case of an unparseable string.
:param col: string column in json format
:param schema: a StructType or ArrayType of StructType to use when parsing the json column.
@@ -1958,6 +2168,8 @@ def from_json(col, schema, options={}):
[Row(json=Row(a=1))]
>>> df.select(from_json(df.value, "a INT").alias("json")).collect()
[Row(json=Row(a=1))]
+ >>> df.select(from_json(df.value, "MAP<STRING,INT>").alias("json")).collect()
+ [Row(json={u'a': 1})]
>>> data = [(1, '''[{"a": 1}]''')]
>>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
>>> df = spark.createDataFrame(data, ("key", "value"))
@@ -2024,24 +2236,108 @@ def size(col):
return Column(sc._jvm.functions.size(_to_java_column(col)))
+@since(2.4)
+def array_min(col):
+ """
+ Collection function: returns the minimum value of the array.
+
+ :param col: name of column or expression
+
+ >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
+ >>> df.select(array_min(df.data).alias('min')).collect()
+ [Row(min=1), Row(min=-1)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_min(_to_java_column(col)))
+
+
+@since(2.4)
+def array_max(col):
+ """
+ Collection function: returns the maximum value of the array.
+
+ :param col: name of column or expression
+
+ >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
+ >>> df.select(array_max(df.data).alias('max')).collect()
+ [Row(max=3), Row(max=10)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_max(_to_java_column(col)))
+
+
@since(1.5)
def sort_array(col, asc=True):
"""
Collection function: sorts the input array in ascending or descending order according
- to the natural ordering of the array elements.
+ to the natural ordering of the array elements. Null elements will be placed at the beginning
+ of the returned array in ascending order or at the end of the returned array in descending
+ order.
:param col: name of column or expression
- >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
+ >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
>>> df.select(sort_array(df.data).alias('r')).collect()
- [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
+ [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])]
>>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
- [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])]
- """
+ [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])]
+ """
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
+@since(2.4)
+def array_sort(col):
+ """
+ Collection function: sorts the input array in ascending order. The elements of the input array
+ must be orderable. Null elements will be placed at the end of the returned array.
+
+ :param col: name of column or expression
+
+ >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
+ >>> df.select(array_sort(df.data).alias('r')).collect()
+ [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_sort(_to_java_column(col)))
+
+
+@since(1.5)
+@ignore_unicode_prefix
+def reverse(col):
+ """
+ Collection function: returns a reversed string or an array with reverse order of elements.
+
+ :param col: name of column or expression
+
+ >>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
+ >>> df.select(reverse(df.data).alias('s')).collect()
+ [Row(s=u'LQS krapS')]
+ >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data'])
+ >>> df.select(reverse(df.data).alias('r')).collect()
+ [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.reverse(_to_java_column(col)))
+
+
+@since(2.4)
+def flatten(col):
+ """
+ Collection function: creates a single array from an array of arrays.
+ If a structure of nested arrays is deeper than two levels,
+ only one level of nesting is removed.
+
+ :param col: name of column or expression
+
+ >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
+ >>> df.select(flatten(df.data).alias('r')).collect()
+ [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.flatten(_to_java_column(col)))
+
+
@since(2.3)
def map_keys(col):
"""
@@ -2082,6 +2378,57 @@ def map_values(col):
return Column(sc._jvm.functions.map_values(_to_java_column(col)))
+@since(2.4)
+def map_entries(col):
+ """
+ Collection function: Returns an unordered array of all entries in the given map.
+
+ :param col: name of column or expression
+
+ >>> from pyspark.sql.functions import map_entries
+ >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
+ >>> df.select(map_entries("data").alias("entries")).show()
+ +----------------+
+ | entries|
+ +----------------+
+ |[[1, a], [2, b]]|
+ +----------------+
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.map_entries(_to_java_column(col)))
+
+
+@ignore_unicode_prefix
+@since(2.4)
+def array_repeat(col, count):
+ """
+ Collection function: creates an array containing a column repeated count times.
+
+ >>> df = spark.createDataFrame([('ab',)], ['data'])
+ >>> df.select(array_repeat(df.data, 3).alias('r')).collect()
+ [Row(r=[u'ab', u'ab', u'ab'])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count))
+
+
+@since(2.4)
+def arrays_zip(*cols):
+ """
+ Collection function: Returns a merged array of structs in which the N-th struct contains all
+ N-th values of input arrays.
+
+ :param cols: columns of arrays to be merged.
+
+ >>> from pyspark.sql.functions import arrays_zip
+ >>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2'])
+ >>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect()
+ [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column)))
+
+
# ---------------------------- User Defined Function ----------------------------------
class PandasUDFType(object):
@@ -2111,6 +2458,8 @@ def udf(f=None, returnType=StringType()):
in boolean expressions and it ends up with being executed all internally. If the functions
can fail on special rows, the workaround is to incorporate the condition into the functions.
+ .. note:: The user-defined functions do not take keyword arguments on the calling side.
+
:param f: python function if used as a standalone function
:param returnType: the return type of the user-defined function. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
@@ -2158,6 +2507,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
:param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`.
Default: SCALAR.
+ .. note:: Experimental
+
The function type of the UDF can be one of the following:
1. SCALAR
@@ -2200,7 +2551,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame`
The returnType should be a :class:`StructType` describing the schema of the returned
`pandas.DataFrame`.
- The length of the returned `pandas.DataFrame` can be arbitrary.
+ The length of the returned `pandas.DataFrame` can be arbitrary and the columns must be
+ indexed so that their position matches the corresponding field in the schema.
Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`.
@@ -2223,6 +2575,37 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 1.1094003924504583|
+---+-------------------+
+ Alternatively, the user can define a function that takes two arguments.
+ In this case, the grouping key will be passed as the first argument and the data will
+ be passed as the second argument. The grouping key will be passed as a tuple of numpy
+ data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
+ as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
+ This is useful when the user does not want to hardcode grouping key in the function.
+
+ >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+ >>> import pandas as pd # doctest: +SKIP
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+ ... ("id", "v")) # doctest: +SKIP
+ >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
+ ... def mean_udf(key, pdf):
+ ... # key is a tuple of one numpy.int64, which is the value
+ ... # of 'id' for the current group
+ ... return pd.DataFrame([key + (pdf.v.mean(),)])
+ >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP
+ +---+---+
+ | id| v|
+ +---+---+
+ | 1|1.5|
+ | 2|6.0|
+ +---+---+
+
+ .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is
+ recommended to explicitly index the columns by name to ensure the positions are correct,
+ or alternatively use an `OrderedDict`.
+ For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or
+ `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`.
+
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
3. GROUPED_AGG
@@ -2232,10 +2615,12 @@ def pandas_udf(f=None, returnType=None, functionType=None):
The returned scalar can be either a python primitive type, e.g., `int` or `float`
or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
- :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as
- output types.
+ :class:`MapType` and :class:`StructType` are currently not supported as output types.
+
+ Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and
+ :class:`pyspark.sql.Window`
- Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg`
+ This example shows using grouped aggregate UDFs with groupby:
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
@@ -2252,7 +2637,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 6.0|
+---+-----------+
- .. seealso:: :meth:`pyspark.sql.GroupedData.agg`
+ This example shows using grouped aggregate UDFs as window functions. Note that only
+ the unbounded window frame is supported at the moment:
+
+ >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+ >>> from pyspark.sql import Window
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+ ... ("id", "v"))
+ >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
+ ... def mean_udf(v):
+ ... return v.mean()
+ >>> w = Window.partitionBy('id') \\
+ ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+ >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP
+ +---+----+------+
+ | id| v|mean_v|
+ +---+----+------+
+ | 1| 1.0| 1.5|
+ | 1| 2.0| 1.5|
+ | 2| 3.0| 6.0|
+ | 2| 5.0| 6.0|
+ | 2|10.0| 6.0|
+ +---+----+------+
+
+ .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
.. note:: The user-defined functions are considered deterministic by default. Due to
optimization, duplicate invocations may be eliminated or the function may even be invoked
@@ -2269,6 +2678,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
.. note:: The user-defined functions do not support conditional expressions or short circuiting
in boolean expressions and it ends up with being executed all internally. If the functions
can fail on special rows, the workaround is to incorporate the condition into the functions.
+
+ .. note:: The user-defined functions do not take keyword arguments on the calling side.
"""
# decorator @pandas_udf(returnType, functionType)
is_decorator = f is None or isinstance(f, (str, DataType))
@@ -2335,7 +2746,7 @@ def _test():
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
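
Note (illustrative only, not part of the patch): a short sketch exercising a few of the 2.4 collection functions added to functions.py above, assuming a local SparkSession.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import array_min, array_max, array_position, flatten

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # flatten removes one level of nesting from an array of arrays.
    nested = spark.createDataFrame([([[1, 2], [3]],)], ["data"])
    nested.select(flatten("data").alias("flat")).show()  # [1, 2, 3]

    # array_position uses a 1-based index and returns 0 when the value is absent.
    arrs = spark.createDataFrame([([2, 1, 3],)], ["data"])
    arrs.select(array_min("data"), array_max("data"), array_position("data", 3)).show()

    spark.stop()
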
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index ab646535c864c..0906c9c6b329a 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -15,11 +15,12 @@
# limitations under the License.
#
+import sys
+
from pyspark import since
from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
-from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal
+from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.udf import UserDefinedFunction
from pyspark.sql.types import *
__all__ = ["GroupedData"]
@@ -235,6 +236,8 @@ def apply(self, udf):
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
+ .. note:: Experimental
+
:param udf: a grouped map user-defined function returned by
:func:`pyspark.sql.functions.pandas_udf`.
@@ -299,7 +302,7 @@ def _test():
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 49af1bcee5ef8..3efe2adb6e2a4 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -22,7 +22,7 @@
from py4j.java_gateway import JavaClass
-from pyspark import RDD, since, keyword_only
+from pyspark import RDD, since
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import _to_seq
from pyspark.sql.types import *
@@ -147,8 +147,8 @@ def load(self, path=None, format=None, schema=None, **options):
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
:param options: all other string options
- >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
- ... opt2=1, opt3='str')
+ >>> df = spark.read.format("parquet").load('python/test_support/sql/parquet_partitioned',
+ ... opt1=True, opt2=1, opt3='str')
>>> df.dtypes
[('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
@@ -176,7 +176,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- multiLine=None, allowUnquotedControlChars=None):
+ multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None,
+ dropFieldIfAllNull=None, encoding=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.
@@ -209,13 +210,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param mode: allows a mode for dealing with corrupt records during parsing. If None is
set, it uses the default value, ``PERMISSIVE``.
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
- record, and puts the malformed string into a field configured by \
- ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \
- a string type field named ``columnNameOfCorruptRecord`` in an user-defined \
- schema. If a schema does not have the field, it drops corrupt records during \
- parsing. When inferring a schema, it implicitly adds a \
- ``columnNameOfCorruptRecord`` field in an output schema.
+ * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \
+ into a field configured by ``columnNameOfCorruptRecord``, and sets other \
+ fields to ``null``. To keep corrupt records, a user can set a string type \
+ field named ``columnNameOfCorruptRecord`` in a user-defined schema. If a \
+ schema does not have the field, it drops corrupt records during parsing. \
+ When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \
+ field to the output schema.
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
@@ -237,6 +238,17 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param allowUnquotedControlChars: allows JSON Strings to contain unquoted control
characters (ASCII characters with value less than 32,
including tab and line feed characters) or not.
+ :param encoding: allows forcibly setting one of the standard basic or extended encodings for
+ the JSON files, for example UTF-16BE or UTF-32LE. If None is set,
+ the encoding of the input JSON will be detected automatically
+ when the multiLine option is set to ``true``.
+ :param lineSep: defines the line separator that should be used for parsing. If None is
+ set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+ :param samplingRatio: defines the fraction of input JSON objects used for schema inference.
+ If None is set, it uses the default value, ``1.0``.
+ :param dropFieldIfAllNull: whether to ignore columns of all null values or empty
+ arrays/structs during schema inference. If None is set, it
+ uses the default value, ``false``.
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -254,7 +266,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
- allowUnquotedControlChars=allowUnquotedControlChars)
+ allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep,
+ samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -304,16 +317,18 @@ def parquet(self, *paths):
@ignore_unicode_prefix
@since(1.6)
- def text(self, paths, wholetext=False):
+ def text(self, paths, wholetext=False, lineSep=None):
"""
Loads text files and returns a :class:`DataFrame` whose schema starts with a
string column named "value", and followed by partitioned columns if there
are any.
- Each line in the text file is a new row in the resulting DataFrame.
+ By default, each line in the text file is a new row in the resulting DataFrame.
:param paths: string, or list of strings, for input path(s).
:param wholetext: if true, read each file from input path(s) as a single row.
+ :param lineSep: defines the line separator that should be used for parsing. If None is
+ set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
>>> df = spark.read.text('python/test_support/sql/text-test.txt')
>>> df.collect()
@@ -322,7 +337,7 @@ def text(self, paths, wholetext=False):
>>> df.collect()
[Row(value=u'hello\\nthis')]
"""
- self._set_opts(wholetext=wholetext)
+ self._set_opts(wholetext=wholetext, lineSep=lineSep)
if isinstance(paths, basestring):
paths = [paths]
return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths)))
@@ -333,7 +348,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
- columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None):
+ columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
+ samplingRatio=None, enforceSchema=None):
"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -360,6 +376,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
default value, ``false``.
:param inferSchema: infers the input schema automatically from data. It requires one extra
pass over the data. If None is set, it uses the default value, ``false``.
+ :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be
+ forcibly applied to datasource files, and headers in CSV files will be
+ ignored. If the option is set to ``false``, the schema will be
+ validated against all headers in CSV files or the first header in RDD
+ if the ``header`` option is set to ``true``. Field names in the schema
+ and column names in CSV headers are checked by their positions
+ taking into account ``spark.sql.caseSensitive``. If None is set,
+ ``true`` is used by default. Though the default value is ``true``,
+ it is recommended to disable the ``enforceSchema`` option
+ to avoid incorrect results.
:param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from
values being read should be skipped. If None is set, it
uses the default value, ``false``.
@@ -393,13 +419,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param mode: allows a mode for dealing with corrupt records during parsing. If None is
set, it uses the default value, ``PERMISSIVE``.
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
- record, and puts the malformed string into a field configured by \
- ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \
- a string type field named ``columnNameOfCorruptRecord`` in an \
- user-defined schema. If a schema does not have the field, it drops corrupt \
- records during parsing. When a length of parsed CSV tokens is shorter than \
- an expected length of a schema, it sets `null` for extra fields.
+ * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \
+ into a field configured by ``columnNameOfCorruptRecord``, and sets other \
+ fields to ``null``. To keep corrupt records, a user can set a string type \
+ field named ``columnNameOfCorruptRecord`` in a user-defined schema. If a \
+ schema does not have the field, it drops corrupt records during parsing. \
+ A record with fewer or more tokens than the schema is not a corrupted record to CSV. \
+ When it meets a record having fewer tokens than the length of the schema, it \
+ sets ``null`` for the extra fields. When the record has more tokens than the \
+ length of the schema, it drops the extra tokens.
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
@@ -414,6 +442,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
the quote character. If None is set, the default value is
escape character when escape and quote characters are
different, ``\0`` otherwise.
+ :param samplingRatio: defines the fraction of rows used for schema inference.
+ If None is set, it uses the default value, ``1.0``.
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
@@ -432,7 +462,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
- charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
+ charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
+ enforceSchema=enforceSchema)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -742,7 +773,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options)
self._jwrite.saveAsTable(name)
@since(1.4)
- def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None):
+ def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None,
+ lineSep=None, encoding=None):
"""Saves the content of the :class:`DataFrame` in JSON format
(`JSON Lines text format or newline-delimited JSON <http://jsonlines.org/>`_) at the
specified path.
@@ -766,12 +798,17 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm
formats follow the formats at ``java.text.SimpleDateFormat``.
This applies to timestamp type. If None is set, it uses the
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``.
+ :param encoding: specifies the encoding (charset) of saved json files. If None is set,
+ the default UTF-8 charset will be used.
+ :param lineSep: defines the line separator that should be used for writing. If None is
+ set, it uses the default value, ``\\n``.
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
self._set_opts(
- compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat)
+ compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat,
+ lineSep=lineSep, encoding=encoding)
self._jwrite.json(path)
@since(1.4)
@@ -802,18 +839,20 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None):
self._jwrite.parquet(path)
@since(1.6)
- def text(self, path, compression=None):
+ def text(self, path, compression=None, lineSep=None):
"""Saves the content of the DataFrame in a text file at the specified path.
:param path: the path in any Hadoop supported file system
:param compression: compression codec to use when saving to file. This can be one of the
known case-insensitive shorten names (none, bzip2, gzip, lz4,
snappy and deflate).
+ :param lineSep: defines the line separator that should be used for writing. If None is
+ set, it uses the default value, ``\\n``.
The DataFrame must have only one column that is of string type.
Each row becomes a new line in the output file.
"""
- self._set_opts(compression=compression)
+ self._set_opts(compression=compression, lineSep=lineSep)
self._jwrite.text(path)
@since(2.0)
@@ -954,7 +993,7 @@ def _test():
globs = pyspark.sql.readwriter.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
try:
- spark = SparkSession.builder.enableHiveSupport().getOrCreate()
+ spark = SparkSession.builder.getOrCreate()
except py4j.protocol.Py4JError:
spark = SparkSession(sc)
@@ -968,7 +1007,7 @@ def _test():
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
sc.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
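
Note (illustrative only, not part of the patch): a sketch of the new ``lineSep`` option on the text writer and reader added above; the temporary output path is hypothetical.

    import os
    import tempfile
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    out = os.path.join(tempfile.mkdtemp(), "names")

    df = spark.createDataFrame([("Alice",), ("Bob",)], ["name"])
    # Write records separated by '|' instead of the default '\n'.
    df.write.text(out, lineSep="|")

    # The reader must be given the same separator to split rows correctly.
    spark.read.text(out, lineSep="|").show()

    spark.stop()
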
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 1ed04298bc899..f1ad6b1212ed9 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -547,6 +547,33 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
df._schema = schema
return df
+ @staticmethod
+ def _create_shell_session():
+ """
+ Initialize a SparkSession for a pyspark shell session. This is called from shell.py
+ to make error handling simpler without needing to declare local variables in that
+ script, which would expose those to users.
+ """
+ import py4j
+ from pyspark.conf import SparkConf
+ from pyspark.context import SparkContext
+ try:
+ # Try to access HiveConf, it will raise exception if Hive is not added
+ conf = SparkConf()
+ if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive':
+ SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf()
+ return SparkSession.builder\
+ .enableHiveSupport()\
+ .getOrCreate()
+ else:
+ return SparkSession.builder.getOrCreate()
+ except (py4j.protocol.Py4JError, TypeError):
+ if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive':
+ warnings.warn("Fall back to non-hive support because failing to access HiveConf, "
+ "please make sure you build spark with hive")
+
+ return SparkSession.builder.getOrCreate()
+
@since(2.0)
@ignore_unicode_prefix
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
@@ -584,6 +611,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
.. versionchanged:: 2.1
Added verifySchema.
+ .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental.
+
>>> l = [('Alice', 1)]
>>> spark.createDataFrame(l).collect()
[Row(_1=u'Alice', _2=1)]
@@ -646,6 +675,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
except Exception:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+
if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
== "true":
timezone = self.conf.get("spark.sql.session.timeZone")
@@ -663,8 +695,27 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
except Exception as e:
- warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
- # Fallback to create DataFrame without arrow if raise some exception
+ from pyspark.util import _exception_message
+
+ if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \
+ .lower() == "true":
+ msg = (
+ "createDataFrame attempted Arrow optimization because "
+ "'spark.sql.execution.arrow.enabled' is set to true; however, "
+ "failed by the reason below:\n %s\n"
+ "Attempting non-optimization as "
+ "'spark.sql.execution.arrow.fallback.enabled' is set to "
+ "true." % _exception_message(e))
+ warnings.warn(msg)
+ else:
+ msg = (
+ "createDataFrame attempted Arrow optimization because "
+ "'spark.sql.execution.arrow.enabled' is set to true, but has reached "
+ "the error below and will not continue because automatic fallback "
+ "with 'spark.sql.execution.arrow.fallback.enabled' has been set to "
+ "false.\n %s" % _exception_message(e))
+ warnings.warn(msg)
+ raise
data = self._convert_from_pandas(data, schema, timezone)
if isinstance(schema, StructType):
@@ -809,7 +860,7 @@ def _test():
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
globs['sc'].stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
_test()
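
Note (illustrative only, not part of the patch): how the two Arrow configs referenced in the createDataFrame change above interact; requires pandas (and pyarrow for the optimized path).

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # Opt in to the Arrow-based conversion (still experimental per the note above).
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    # With fallback enabled (the default), an Arrow failure only warns and the plain
    # conversion path is used; with it disabled, the original error is raised.
    spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "true")

    pdf = pd.DataFrame({"id": [1, 2], "v": [0.1, 0.2]})
    spark.createDataFrame(pdf).show()

    spark.stop()
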
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index e2a97acb5e2a7..4984593bab491 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -24,8 +24,6 @@
else:
intlike = (int, long)
-from abc import ABCMeta, abstractmethod
-
from pyspark import since, keyword_only
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import _to_seq
@@ -407,7 +405,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- multiLine=None, allowUnquotedControlChars=None):
+ multiLine=None, allowUnquotedControlChars=None, lineSep=None):
"""
Loads a JSON file stream and returns the results as a :class:`DataFrame`.
@@ -442,13 +440,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param mode: allows a mode for dealing with corrupt records during parsing. If None is
set, it uses the default value, ``PERMISSIVE``.
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
- record, and puts the malformed string into a field configured by \
- ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \
- a string type field named ``columnNameOfCorruptRecord`` in an user-defined \
- schema. If a schema does not have the field, it drops corrupt records during \
- parsing. When inferring a schema, it implicitly adds a \
- ``columnNameOfCorruptRecord`` field in an output schema.
+ * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \
+ into a field configured by ``columnNameOfCorruptRecord``, and sets other \
+ fields to ``null``. To keep corrupt records, a user can set a string type \
+ field named ``columnNameOfCorruptRecord`` in a user-defined schema. If a \
+ schema does not have the field, it drops corrupt records during parsing. \
+ When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \
+ field to the output schema.
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
@@ -470,6 +468,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param allowUnquotedControlChars: allows JSON Strings to contain unquoted control
characters (ASCII characters with value less than 32,
including tab and line feed characters) or not.
+ :param lineSep: defines the line separator that should be used for parsing. If None is
+ set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
@@ -484,7 +484,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
- allowUnquotedControlChars=allowUnquotedControlChars)
+ allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
@@ -531,17 +531,20 @@ def parquet(self, path):
@ignore_unicode_prefix
@since(2.0)
- def text(self, path):
+ def text(self, path, wholetext=False, lineSep=None):
"""
Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a
string column named "value", and followed by partitioned columns if there
are any.
- Each line in the text file is a new row in the resulting DataFrame.
+ By default, each line in the text file is a new row in the resulting DataFrame.
.. note:: Evolving.
:param paths: string, or list of strings, for input path(s).
+ :param wholetext: if true, read each file from input path(s) as a single row.
+ :param lineSep: defines the line separator that should be used for parsing. If None is
+ set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
>>> text_sdf = spark.readStream.text(tempfile.mkdtemp())
>>> text_sdf.isStreaming
@@ -549,6 +552,7 @@ def text(self, path):
>>> "value" in str(text_sdf.schema)
True
"""
+ self._set_opts(wholetext=wholetext, lineSep=lineSep)
if isinstance(path, basestring):
return self._df(self._jreader.text(path))
else:
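A short, hedged sketch of the new ``wholetext`` and ``lineSep`` options on the streaming
text reader (the input path is an assumption; ``spark`` is an active session):

```python
# Each input file becomes a single row in the "value" column.
whole_sdf = spark.readStream.text("/tmp/stream/text", wholetext=True)

# Rows are split on '!' instead of the default \r, \r\n and \n separators.
split_sdf = spark.readStream.text("/tmp/stream/text", lineSep="!")
```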
@@ -560,7 +564,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
- columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None):
+ columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
+ enforceSchema=None):
"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -588,6 +593,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
default value, ``false``.
:param inferSchema: infers the input schema automatically from data. It requires one extra
pass over the data. If None is set, it uses the default value, ``false``.
+ :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be
+ forcibly applied to datasource files, and headers in CSV files will be
+ ignored. If the option is set to ``false``, the schema will be
+ validated against all headers in CSV files or the first header in RDD
+ if the ``header`` option is set to ``true``. Field names in the schema
+ and column names in CSV headers are checked by their positions
+ taking into account ``spark.sql.caseSensitive``. If None is set,
+ ``true`` is used by default. Though the default value is ``true``,
+ it is recommended to disable the ``enforceSchema`` option
+ to avoid incorrect results.
:param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from
values being read should be skipped. If None is set, it
uses the default value, ``false``.
@@ -621,13 +636,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
:param mode: allows a mode for dealing with corrupt records during parsing. If None is
set, it uses the default value, ``PERMISSIVE``.
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
- record, and puts the malformed string into a field configured by \
- ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \
- a string type field named ``columnNameOfCorruptRecord`` in an \
- user-defined schema. If a schema does not have the field, it drops corrupt \
- records during parsing. When a length of parsed CSV tokens is shorter than \
- an expected length of a schema, it sets `null` for extra fields.
+ * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \
+ into a field configured by ``columnNameOfCorruptRecord``, and sets other \
+ fields to ``null``. To keep corrupt records, a user can set a string type \
+ field named ``columnNameOfCorruptRecord`` in a user-defined schema. If a \
+ schema does not have the field, it drops corrupt records during parsing. \
+ A record with fewer or more tokens than the schema is not considered corrupted \
+ in CSV. When a record has fewer tokens than the length of the schema, \
+ it sets ``null`` for the extra fields. When a record has more tokens than the \
+ length of the schema, it drops the extra tokens.
* ``DROPMALFORMED`` : ignores the whole corrupted records.
* ``FAILFAST`` : throws an exception when it meets corrupted records.
@@ -658,7 +675,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxCharsPerColumn=maxCharsPerColumn,
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
- charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
+ charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
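A hedged sketch of the new ``enforceSchema`` option on the streaming CSV reader; setting it
to ``false`` validates the declared schema against the CSV headers, as exercised by the
``test_checking_csv_header`` test added further below. The path and schema are assumptions.

```python
from pyspark.sql.types import StructType, StructField, IntegerType

csv_schema = StructType([
    StructField("f1", IntegerType()),
    StructField("f2", IntegerType())])

# With enforceSchema=False, swapped or misnamed header columns fail loudly
# instead of being silently read under the wrong field names.
csv_sdf = (spark.readStream
           .schema(csv_schema)
           .option("header", "true")
           .csv("/tmp/stream/csv", enforceSchema=False))
```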
@@ -837,6 +854,168 @@ def trigger(self, processingTime=None, once=None, continuous=None):
self._jwrite = self._jwrite.trigger(jTrigger)
return self
+ @since(2.4)
+ def foreach(self, f):
+ """
+ Sets the output of the streaming query to be processed using the provided writer ``f``.
+ This is often used to write the output of a streaming query to arbitrary storage systems.
+ The processing logic can be specified in two ways.
+
+ #. A **function** that takes a row as input.
+ This is a simple way to express your processing logic. Note that this does
+ not allow you to deduplicate generated data when failures cause reprocessing of
+ some input data. That would require you to specify the processing logic in the next
+ way.
+
+ #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods.
+ The object can have the following methods.
+
+ * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing
+ (for example, open a connection, start a transaction, etc). Additionally, you can
+ use the `partition_id` and `epoch_id` to deduplicate regenerated data
+ (discussed later).
+
+ * ``process(row)``: *Non-optional* method that processes each :class:`Row`.
+
+ * ``close(error)``: *Optional* method that finalizes and cleans up (for example,
+ close connection, commit transaction, etc.) after all rows have been processed.
+
+ The object will be used by Spark in the following way.
+
+ * A single copy of this object is responsible for all the data generated by a
+ single task in a query. In other words, one instance is responsible for
+ processing one partition of the data generated in a distributed manner.
+
+ * This object must be serializable because each task will get a fresh
+ serialized-deserialized copy of the provided object. Hence, it is strongly
+ recommended that any initialization for writing data (e.g. opening a
+ connection or starting a transaction) is done after the `open(...)`
+ method has been called, which signifies that the task is ready to generate data.
+
+ * The lifecycle of the methods is as follows.
+
+ For each partition with ``partition_id``:
+
+ ... For each batch/epoch of streaming data with ``epoch_id``:
+
+ ....... Method ``open(partitionId, epochId)`` is called.
+
+ ....... If ``open(...)`` returns true, for each row in the partition and
+ batch/epoch, method ``process(row)`` is called.
+
+ ....... Method ``close(errorOrNull)`` is called with error (if any) seen while
+ processing rows.
+
+ Important points to note:
+
+ * The `partitionId` and `epochId` can be used to deduplicate generated data when
+ failures cause reprocessing of some input data. This depends on the execution
+ mode of the query. If the streaming query is being executed in the micro-batch
+ mode, then every partition represented by a unique tuple (partition_id, epoch_id)
+ is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used
+ to deduplicate and/or transactionally commit data and achieve exactly-once
+ guarantees. However, if the streaming query is being executed in the continuous
+ mode, then this guarantee does not hold and therefore should not be used for
+ deduplication.
+
+ * The ``close()`` method (if it exists) will be called if the ``open()`` method exists and
+ returns successfully (irrespective of the return value), except if the Python
+ worker crashes in the middle.
+
+ .. note:: Evolving.
+
+ >>> # Print every row using a function
+ >>> def print_row(row):
+ ... print(row)
+ ...
+ >>> writer = sdf.writeStream.foreach(print_row)
+ >>> # Print every row using an object with a process() method
+ >>> class RowPrinter:
+ ... def open(self, partition_id, epoch_id):
+ ... print("Opened %d, %d" % (partition_id, epoch_id))
+ ... return True
+ ... def process(self, row):
+ ... print(row)
+ ... def close(self, error):
+ ... print("Closed with error: %s" % str(error))
+ ...
+ >>> writer = sdf.writeStream.foreach(RowPrinter())
+ """
+
+ from pyspark.rdd import _wrap_function
+ from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+ from pyspark.taskcontext import TaskContext
+
+ if callable(f):
+ # The provided object is a callable function that is supposed to be called on each row.
+ # Construct a function that takes an iterator and calls the provided function on each
+ # row.
+ def func_without_process(_, iterator):
+ for x in iterator:
+ f(x)
+ return iter([])
+
+ func = func_without_process
+
+ else:
+ # The provided object is not a callable function. Then it is expected to have a
+ # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and
+ # 'close(error)' methods.
+
+ if not hasattr(f, 'process'):
+ raise Exception("Provided object does not have a 'process' method")
+
+ if not callable(getattr(f, 'process')):
+ raise Exception("Attribute 'process' in provided object is not callable")
+
+ def doesMethodExist(method_name):
+ exists = hasattr(f, method_name)
+ if exists and not callable(getattr(f, method_name)):
+ raise Exception(
+ "Attribute '%s' in provided object is not callable" % method_name)
+ return exists
+
+ open_exists = doesMethodExist('open')
+ close_exists = doesMethodExist('close')
+
+ def func_with_open_process_close(partition_id, iterator):
+ epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId')
+ if epoch_id:
+ epoch_id = int(epoch_id)
+ else:
+ raise Exception("Could not get batch id from TaskContext")
+
+ # Check if the data should be processed
+ should_process = True
+ if open_exists:
+ should_process = f.open(partition_id, epoch_id)
+
+ error = None
+
+ try:
+ if should_process:
+ for x in iterator:
+ f.process(x)
+ except Exception as ex:
+ error = ex
+ finally:
+ if close_exists:
+ f.close(error)
+ if error:
+ raise error
+
+ return iter([])
+
+ func = func_with_open_process_close
+
+ serializer = AutoBatchedSerializer(PickleSerializer())
+ wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
+ jForeachWriter = \
+ self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
+ wrapped_func, self._df._jdf.schema())
+ self._jwrite.foreach(jForeachWriter)
+ return self
+
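As a hedged illustration of the deduplication guarantee discussed in the docstring above,
here is a minimal sketch of a writer that uses ``(partition_id, epoch_id)`` as an
idempotence key in micro-batch mode. The commit directory and marker-file scheme are
assumptions standing in for a real transactional or idempotent sink.

```python
import os

COMMIT_DIR = "/tmp/foreach-commits"  # hypothetical shared location (e.g. a DFS path)


class IdempotentWriter:
    def open(self, partition_id, epoch_id):
        self.marker = os.path.join(COMMIT_DIR, "%d-%d" % (partition_id, epoch_id))
        self.buffer = []
        # Skip partitions already committed for this epoch so that replays
        # after a failure do not produce duplicates.
        return not os.path.exists(self.marker)

    def process(self, row):
        self.buffer.append(row)

    def close(self, error):
        if error is None:
            # A real sink would write the buffered rows and the commit marker
            # atomically, e.g. in a single transaction.
            if not os.path.isdir(COMMIT_DIR):
                os.makedirs(COMMIT_DIR)
            with open(self.marker, "w") as f:
                f.write("%d rows" % len(self.buffer))


# query = sdf.writeStream.foreach(IdempotentWriter()).start()
```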
@ignore_unicode_prefix
@since(2.0)
def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,
@@ -928,7 +1107,7 @@ def _test():
globs['spark'].stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 53da7dd45c2f2..4e5fafa77e109 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -32,7 +32,9 @@
import datetime
import array
import ctypes
+import warnings
import py4j
+from contextlib import contextmanager
try:
import xmlrunner
@@ -48,19 +50,26 @@
else:
import unittest
-_have_pandas = False
-_have_old_pandas = False
+from pyspark.util import _exception_message
+
+_pandas_requirement_message = None
try:
- import pandas
- try:
- from pyspark.sql.utils import require_minimum_pandas_version
- require_minimum_pandas_version()
- _have_pandas = True
- except:
- _have_old_pandas = True
-except:
- # No Pandas, but that's okay, we'll skip those tests
- pass
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+except ImportError as e:
+ # If Pandas version requirement is not satisfied, skip related tests.
+ _pandas_requirement_message = _exception_message(e)
+
+_pyarrow_requirement_message = None
+try:
+ from pyspark.sql.utils import require_minimum_pyarrow_version
+ require_minimum_pyarrow_version()
+except ImportError as e:
+ # If Arrow version requirement is not satisfied, skip related tests.
+ _pyarrow_requirement_message = _exception_message(e)
+
+_have_pandas = _pandas_requirement_message is None
+_have_pyarrow = _pyarrow_requirement_message is None
from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
@@ -75,15 +84,6 @@
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
-_have_arrow = False
-try:
- import pyarrow
- _have_arrow = True
-except:
- # No Arrow, but that's okay, we'll skip those tests
- pass
-
-
class UTCOffsetTimezone(datetime.tzinfo):
"""
Specifies timezone in UTC offset
@@ -186,7 +186,38 @@ def __init__(self, key, value):
self.value = value
-class ReusedSQLTestCase(ReusedPySparkTestCase):
+class SQLTestUtils(object):
+ """
+ This util assumes the instance has a 'spark' attribute that holds a SparkSession.
+ It is usually mixed into 'ReusedSQLTestCase', but it can be used by any class that
+ is sure to provide a 'spark' attribute.
+ """
+
+ @contextmanager
+ def sql_conf(self, pairs):
+ """
+ A convenient context manager to test configuration-specific logic. It sets each given
+ configuration `key` to its `value` and restores the original values when it exits.
+ """
+ assert isinstance(pairs, dict), "pairs should be a dictionary."
+ assert hasattr(self, "spark"), "it should have a 'spark' attribute holding a SparkSession."
+
+ keys = pairs.keys()
+ new_values = pairs.values()
+ old_values = [self.spark.conf.get(key, None) for key in keys]
+ for key, new_value in zip(keys, new_values):
+ self.spark.conf.set(key, new_value)
+ try:
+ yield
+ finally:
+ for key, old_value in zip(keys, old_values):
+ if old_value is None:
+ self.spark.conf.unset(key)
+ else:
+ self.spark.conf.set(key, old_value)
+
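For illustration, a minimal sketch of how a test built on this mixin might use ``sql_conf``;
it mirrors the ``test_join_without_on`` change further below, and assumes the imports of
``tests.py`` (e.g. ``Row``) are available:

```python
class CrossJoinConfTest(ReusedSQLTestCase):
    def test_cross_join_requires_flag(self):
        df1 = self.spark.range(1).toDF("a")
        df2 = self.spark.range(1).toDF("b")
        # The configuration applies only inside the block and is restored on exit,
        # so other tests sharing the session are unaffected.
        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
            self.assertEqual(df1.join(df2, how="inner").collect(), [Row(a=0, b=0)])
```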
+
+class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
@classmethod
def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
@@ -626,12 +657,58 @@ def test_non_existed_udaf(self):
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
- def test_multiLine_json(self):
+ def test_linesep_text(self):
+ df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",")
+ expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'),
+ Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'),
+ Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'),
+ Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')]
+ self.assertEqual(df.collect(), expected)
+
+ tpath = tempfile.mkdtemp()
+ shutil.rmtree(tpath)
+ try:
+ df.write.text(tpath, lineSep="!")
+ expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'),
+ Row(value=u'Tom!30!"My name is Tom"'),
+ Row(value=u'Hyukjin!25!"I am Hyukjin'),
+ Row(value=u''), Row(value=u'I love Spark!"'),
+ Row(value=u'!')]
+ readback = self.spark.read.text(tpath)
+ self.assertEqual(readback.collect(), expected)
+ finally:
+ shutil.rmtree(tpath)
+
+ def test_multiline_json(self):
people1 = self.spark.read.json("python/test_support/sql/people.json")
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
multiLine=True)
self.assertEqual(people1.collect(), people_array.collect())
+ def test_encoding_json(self):
+ people_array = self.spark.read\
+ .json("python/test_support/sql/people_array_utf16le.json",
+ multiLine=True, encoding="UTF-16LE")
+ expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')]
+ self.assertEqual(people_array.collect(), expected)
+
+ def test_linesep_json(self):
+ df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",")
+ expected = [Row(_corrupt_record=None, name=u'Michael'),
+ Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None),
+ Row(_corrupt_record=u' "age":19}\n', name=None)]
+ self.assertEqual(df.collect(), expected)
+
+ tpath = tempfile.mkdtemp()
+ shutil.rmtree(tpath)
+ try:
+ df = self.spark.read.json("python/test_support/sql/people.json")
+ df.write.json(tpath, lineSep="!!")
+ readback = self.spark.read.json(tpath, lineSep="!!")
+ self.assertEqual(readback.collect(), df.collect())
+ finally:
+ shutil.rmtree(tpath)
+
def test_multiline_csv(self):
ages_newlines = self.spark.read.csv(
"python/test_support/sql/ages_newlines.csv", multiLine=True)
@@ -1792,6 +1869,263 @@ def test_query_manager_await_termination(self):
q.stop()
shutil.rmtree(tmpPath)
+ class ForeachWriterTester:
+
+ def __init__(self, spark):
+ self.spark = spark
+
+ def write_open_event(self, partitionId, epochId):
+ self._write_event(
+ self.open_events_dir,
+ {'partition': partitionId, 'epoch': epochId})
+
+ def write_process_event(self, row):
+ self._write_event(self.process_events_dir, {'value': 'text'})
+
+ def write_close_event(self, error):
+ self._write_event(self.close_events_dir, {'error': str(error)})
+
+ def write_input_file(self):
+ self._write_event(self.input_dir, "text")
+
+ def open_events(self):
+ return self._read_events(self.open_events_dir, 'partition INT, epoch INT')
+
+ def process_events(self):
+ return self._read_events(self.process_events_dir, 'value STRING')
+
+ def close_events(self):
+ return self._read_events(self.close_events_dir, 'error STRING')
+
+ def run_streaming_query_on_writer(self, writer, num_files):
+ self._reset()
+ try:
+ sdf = self.spark.readStream.format('text').load(self.input_dir)
+ sq = sdf.writeStream.foreach(writer).start()
+ for i in range(num_files):
+ self.write_input_file()
+ sq.processAllAvailable()
+ finally:
+ self.stop_all()
+
+ def assert_invalid_writer(self, writer, msg=None):
+ self._reset()
+ try:
+ sdf = self.spark.readStream.format('text').load(self.input_dir)
+ sq = sdf.writeStream.foreach(writer).start()
+ self.write_input_file()
+ sq.processAllAvailable()
+ self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected
+ except Exception as e:
+ if msg:
+ assert msg in str(e), "%s not in %s" % (msg, str(e))
+
+ finally:
+ self.stop_all()
+
+ def stop_all(self):
+ for q in self.spark._wrapped.streams.active:
+ q.stop()
+
+ def _reset(self):
+ self.input_dir = tempfile.mkdtemp()
+ self.open_events_dir = tempfile.mkdtemp()
+ self.process_events_dir = tempfile.mkdtemp()
+ self.close_events_dir = tempfile.mkdtemp()
+
+ def _read_events(self, dir, json):
+ rows = self.spark.read.schema(json).json(dir).collect()
+ dicts = [row.asDict() for row in rows]
+ return dicts
+
+ def _write_event(self, dir, event):
+ import uuid
+ with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f:
+ f.write("%s\n" % str(event))
+
+ def __getstate__(self):
+ return (self.open_events_dir, self.process_events_dir, self.close_events_dir)
+
+ def __setstate__(self, state):
+ self.open_events_dir, self.process_events_dir, self.close_events_dir = state
+
+ def test_streaming_foreach_with_simple_function(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ def foreach_func(row):
+ tester.write_process_event(row)
+
+ tester.run_streaming_query_on_writer(foreach_func, 2)
+ self.assertEqual(len(tester.process_events()), 2)
+
+ def test_streaming_foreach_with_basic_open_process_close(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partitionId, epochId):
+ tester.write_open_event(partitionId, epochId)
+ return True
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+ open_events = tester.open_events()
+ self.assertEqual(len(open_events), 2)
+ self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})
+
+ self.assertEqual(len(tester.process_events()), 2)
+
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 2)
+ self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+ def test_streaming_foreach_with_open_returning_false(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partition_id, epoch_id):
+ tester.write_open_event(partition_id, epoch_id)
+ return False
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+ self.assertEqual(len(tester.open_events()), 2)
+
+ self.assertEqual(len(tester.process_events()), 0) # no row was processed
+
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 2)
+ self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+ def test_streaming_foreach_without_open_method(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+ self.assertEqual(len(tester.open_events()), 0) # no open events
+ self.assertEqual(len(tester.process_events()), 2)
+ self.assertEqual(len(tester.close_events()), 2)
+
+ def test_streaming_foreach_without_close_method(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partition_id, epoch_id):
+ tester.write_open_event(partition_id, epoch_id)
+ return True
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+ self.assertEqual(len(tester.open_events()), 2) # two open events
+ self.assertEqual(len(tester.process_events()), 2)
+ self.assertEqual(len(tester.close_events()), 0)
+
+ def test_streaming_foreach_without_open_and_close_methods(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def process(self, row):
+ tester.write_process_event(row)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+ self.assertEqual(len(tester.open_events()), 0) # no open events
+ self.assertEqual(len(tester.process_events()), 2)
+ self.assertEqual(len(tester.close_events()), 0)
+
+ def test_streaming_foreach_with_process_throwing_error(self):
+ from pyspark.sql.utils import StreamingQueryException
+
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def process(self, row):
+ raise Exception("test error")
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ try:
+ tester.run_streaming_query_on_writer(ForeachWriter(), 1)
+ self.fail("bad writer did not fail the query") # this is not expected
+ except StreamingQueryException as e:
+ # TODO: Verify whether original error message is inside the exception
+ pass
+
+ self.assertEqual(len(tester.process_events()), 0) # no row was processed
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 1)
+ # TODO: Verify whether original error message is inside the exception
+
+ def test_streaming_foreach_with_invalid_writers(self):
+
+ tester = self.ForeachWriterTester(self.spark)
+
+ def func_with_iterator_input(iter):
+ for x in iter:
+ print(x)
+
+ tester.assert_invalid_writer(func_with_iterator_input)
+
+ class WriterWithoutProcess:
+ def open(self, partition):
+ pass
+
+ tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'")
+
+ class WriterWithNonCallableProcess():
+ process = True
+
+ tester.assert_invalid_writer(WriterWithNonCallableProcess(),
+ "'process' in provided object is not callable")
+
+ class WriterWithNoParamProcess():
+ def process(self):
+ pass
+
+ tester.assert_invalid_writer(WriterWithNoParamProcess())
+
+ # Base class providing a valid process() method for the tests below
+ class WithProcess():
+ def process(self, row):
+ pass
+
+ class WriterWithNonCallableOpen(WithProcess):
+ open = True
+
+ tester.assert_invalid_writer(WriterWithNonCallableOpen(),
+ "'open' in provided object is not callable")
+
+ class WriterWithNoParamOpen(WithProcess):
+ def open(self):
+ pass
+
+ tester.assert_invalid_writer(WriterWithNoParamOpen())
+
+ class WriterWithNonCallableClose(WithProcess):
+ close = True
+
+ tester.assert_invalid_writer(WriterWithNonCallableClose(),
+ "'close' in provided object is not callable")
+
def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
@@ -2150,6 +2484,34 @@ def test_expr(self):
result = df.select(functions.expr("length(a)")).collect()[0].asDict()
self.assertEqual(13, result["length(a)"])
+ def test_repartitionByRange_dataframe(self):
+ schema = StructType([
+ StructField("name", StringType(), True),
+ StructField("age", IntegerType(), True),
+ StructField("height", DoubleType(), True)])
+
+ df1 = self.spark.createDataFrame(
+ [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema)
+ df2 = self.spark.createDataFrame(
+ [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema)
+
+ # test repartitionByRange(numPartitions, *cols)
+ df3 = df1.repartitionByRange(2, "name", "age")
+ self.assertEqual(df3.rdd.getNumPartitions(), 2)
+ self.assertEqual(df3.rdd.first(), df2.rdd.first())
+ self.assertEqual(df3.rdd.take(3), df2.rdd.take(3))
+
+ # test repartitionByRange(numPartitions, *cols)
+ df4 = df1.repartitionByRange(3, "name", "age")
+ self.assertEqual(df4.rdd.getNumPartitions(), 3)
+ self.assertEqual(df4.rdd.first(), df2.rdd.first())
+ self.assertEqual(df4.rdd.take(3), df2.rdd.take(3))
+
+ # test repartitionByRange(*cols)
+ df5 = df1.repartitionByRange("name", "age")
+ self.assertEqual(df5.rdd.first(), df2.rdd.first())
+ self.assertEqual(df5.rdd.take(3), df2.rdd.take(3))
+
def test_replace(self):
schema = StructType([
StructField("name", StringType(), True),
@@ -2245,11 +2607,6 @@ def test_replace(self):
.replace(False, True).first())
self.assertTupleEqual(row, (True, True))
- # replace list while value is not given (default to None)
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()
- self.assertTupleEqual(row, (None, 10, 80.0))
-
# replace string with None and then drop None rows
row = self.spark.createDataFrame(
[(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna()
@@ -2285,6 +2642,12 @@ def test_replace(self):
self.spark.createDataFrame(
[(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()
+ with self.assertRaisesRegexp(
+ TypeError,
+ 'value argument is required when to_replace is not a dictionary.'):
+ self.spark.createDataFrame(
+ [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()
+
def test_capture_analysis_exception(self):
self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
@@ -2410,17 +2773,13 @@ def test_join_without_on(self):
df1 = self.spark.range(1).toDF("a")
df2 = self.spark.range(1).toDF("b")
- try:
- self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
+ with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
- self.spark.conf.set("spark.sql.crossJoin.enabled", "true")
+ with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
actual = df1.join(df2, how="inner").collect()
expected = [Row(a=0, b=0)]
self.assertEqual(actual, expected)
- finally:
- # We should unset this. Otherwise, other tests are affected.
- self.spark.conf.unset("spark.sql.crossJoin.enabled")
# Regression test for invalid join methods when on is None, Spark-14761
def test_invalid_join_method(self):
@@ -2453,6 +2812,17 @@ def test_conf(self):
spark.conf.unset("bogo")
self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
+ self.assertEqual(spark.conf.get("hyukjin", None), None)
+
+ # This returns 'STATIC' because it's the default value of
+ # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in
+ # `spark.conf.get` is unset.
+ self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC")
+
+ # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but
+ # `defaultValue` in `spark.conf.get` is set to None.
+ self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None)
+
def test_current_database(self):
spark = self.spark
spark.catalog._reset()
@@ -2794,7 +3164,6 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
def _to_pandas(self):
from datetime import datetime, date
- import numpy as np
schema = StructType().add("a", IntegerType()).add("b", StringType())\
.add("c", BooleanType()).add("d", FloatType())\
.add("dt", DateType()).add("ts", TimestampType())
@@ -2807,7 +3176,7 @@ def _to_pandas(self):
df = self.spark.createDataFrame(data, schema)
return df.toPandas()
- @unittest.skipIf(not _have_pandas, "Pandas not installed")
+ @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
def test_to_pandas(self):
import numpy as np
pdf = self._to_pandas()
@@ -2819,13 +3188,13 @@ def test_to_pandas(self):
self.assertEquals(types[4], np.object) # datetime.date
self.assertEquals(types[5], 'datetime64[ns]')
- @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
- def test_to_pandas_old(self):
+ @unittest.skipIf(_have_pandas, "Required Pandas was found.")
+ def test_to_pandas_required_pandas_not_found(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
self._to_pandas()
- @unittest.skipIf(not _have_pandas, "Pandas not installed")
+ @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
def test_to_pandas_avoid_astype(self):
import numpy as np
schema = StructType().add("a", IntegerType()).add("b", StringType())\
@@ -2843,7 +3212,7 @@ def test_create_dataframe_from_array_of_long(self):
df = self.spark.createDataFrame(data)
self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))
- @unittest.skipIf(not _have_pandas, "Pandas not installed")
+ @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
def test_create_dataframe_from_pandas_with_timestamp(self):
import pandas as pd
from datetime import datetime
@@ -2858,19 +3227,147 @@ def test_create_dataframe_from_pandas_with_timestamp(self):
self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
- @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
- def test_create_dataframe_from_old_pandas(self):
- import pandas as pd
- from datetime import datetime
- pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
- "d": [pd.Timestamp.now().date()]})
+ @unittest.skipIf(_have_pandas, "Required Pandas was found.")
+ def test_create_dataframe_required_pandas_not_found(self):
with QuietTest(self.sc):
- with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
+ with self.assertRaisesRegexp(
+ ImportError,
+ "(Pandas >= .* must be installed|No module named '?pandas'?)"):
+ import pandas as pd
+ from datetime import datetime
+ pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
+ "d": [pd.Timestamp.now().date()]})
self.spark.createDataFrame(pdf)
+ # Regression test for SPARK-23360
+ @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
+ def test_create_dateframe_from_pandas_with_dst(self):
+ import pandas as pd
+ from datetime import datetime
+
+ pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})
+
+ df = self.spark.createDataFrame(pdf)
+ self.assertPandasEqual(pdf, df.toPandas())
+
+ orig_env_tz = os.environ.get('TZ', None)
+ try:
+ tz = 'America/Los_Angeles'
+ os.environ['TZ'] = tz
+ time.tzset()
+ with self.sql_conf({'spark.sql.session.timeZone': tz}):
+ df = self.spark.createDataFrame(pdf)
+ self.assertPandasEqual(pdf, df.toPandas())
+ finally:
+ del os.environ['TZ']
+ if orig_env_tz is not None:
+ os.environ['TZ'] = orig_env_tz
+ time.tzset()
+
+ def test_sort_with_nulls_order(self):
+ from pyspark.sql import functions
+
+ df = self.spark.createDataFrame(
+ [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"])
+ self.assertEquals(
+ df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(),
+ [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')])
+ self.assertEquals(
+ df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(),
+ [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)])
+ self.assertEquals(
+ df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(),
+ [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')])
+ self.assertEquals(
+ df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
+ [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])
+
+ def test_json_sampling_ratio(self):
+ rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
+ .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x))
+ schema = self.spark.read.option('inferSchema', True) \
+ .option('samplingRatio', 0.5) \
+ .json(rdd).schema
+ self.assertEquals(schema, StructType([StructField("a", LongType(), True)]))
+
+ def test_csv_sampling_ratio(self):
+ rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
+ .map(lambda x: '0.1' if x == 1 else str(x))
+ schema = self.spark.read.option('inferSchema', True)\
+ .csv(rdd, samplingRatio=0.5).schema
+ self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)]))
+
+ def test_checking_csv_header(self):
+ path = tempfile.mkdtemp()
+ shutil.rmtree(path)
+ try:
+ self.spark.createDataFrame([[1, 1000], [2000, 2]])\
+ .toDF('f1', 'f2').write.option("header", "true").csv(path)
+ schema = StructType([
+ StructField('f2', IntegerType(), nullable=True),
+ StructField('f1', IntegerType(), nullable=True)])
+ df = self.spark.read.option('header', 'true').schema(schema)\
+ .csv(path, enforceSchema=False)
+ self.assertRaisesRegexp(
+ Exception,
+ "CSV header does not conform to the schema",
+ lambda: df.collect())
+ finally:
+ shutil.rmtree(path)
+
+ def test_repr_html(self):
+ import re
+ pattern = re.compile(r'^ *\|', re.MULTILINE)
+ df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
+ self.assertEquals(None, df._repr_html_())
+ with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
+ expected1 = """
+ |only showing top 1 row
+ |"""
+ self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
+
class HiveSparkSubmitTests(SparkSubmitTests):
+ @classmethod
+ def setUpClass(cls):
+ # get a SparkContext to check for availability of Hive
+ sc = SparkContext('local[4]', cls.__name__)
+ cls.hive_available = True
+ try:
+ sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+ except py4j.protocol.Py4JError:
+ cls.hive_available = False
+ except TypeError:
+ cls.hive_available = False
+ finally:
+ # we don't need this SparkContext for the test
+ sc.stop()
+
+ def setUp(self):
+ super(HiveSparkSubmitTests, self).setUp()
+ if not self.hive_available:
+ self.skipTest("Hive is not available.")
+
def test_hivecontext(self):
# This test checks that HiveContext is using Hive metastore (SPARK-16224).
# It sets a metastore url and checks if there is a derby dir created by
@@ -2900,8 +3397,8 @@ def test_hivecontext(self):
|print(hive_context.sql("show databases").collect())
""")
proc = subprocess.Popen(
- [self.sparkSubmit, "--master", "local-cluster[1,1,1024]",
- "--driver-class-path", hive_site_dir, script],
+ self.sparkSubmit + ["--master", "local-cluster[1,1,1024]",
+ "--driver-class-path", hive_site_dir, script],
stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
@@ -2925,6 +3422,69 @@ def test_sparksession_with_stopped_sparkcontext(self):
sc.stop()
+class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
+ # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
+ # static and immutable. This can't be set or unset, for example, via `spark.conf`.
+
+ @classmethod
+ def setUpClass(cls):
+ import glob
+ from pyspark.find_spark_home import _find_spark_home
+
+ SPARK_HOME = _find_spark_home()
+ filename_pattern = (
+ "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+ "TestQueryExecutionListener.class")
+ cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern)))
+
+ if cls.has_listener:
+ # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
+ cls.spark = SparkSession.builder \
+ .master("local[4]") \
+ .appName(cls.__name__) \
+ .config(
+ "spark.sql.queryExecutionListeners",
+ "org.apache.spark.sql.TestQueryExecutionListener") \
+ .getOrCreate()
+
+ def setUp(self):
+ if not self.has_listener:
+ raise self.skipTest(
+ "'org.apache.spark.sql.TestQueryExecutionListener' is not "
+ "available. Will skip the related tests.")
+
+ @classmethod
+ def tearDownClass(cls):
+ if hasattr(cls, "spark"):
+ cls.spark.stop()
+
+ def tearDown(self):
+ self.spark._jvm.OnSuccessCall.clear()
+
+ def test_query_execution_listener_on_collect(self):
+ self.assertFalse(
+ self.spark._jvm.OnSuccessCall.isCalled(),
+ "The callback from the query execution listener should not be called before 'collect'")
+ self.spark.sql("SELECT * FROM range(1)").collect()
+ self.assertTrue(
+ self.spark._jvm.OnSuccessCall.isCalled(),
+ "The callback from the query execution listener should be called after 'collect'")
+
+ @unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
+ def test_query_execution_listener_on_collect_with_arrow(self):
+ with self.sql_conf({"spark.sql.execution.arrow.enabled": True}):
+ self.assertFalse(
+ self.spark._jvm.OnSuccessCall.isCalled(),
+ "The callback from the query execution listener should not be "
+ "called before 'toPandas'")
+ self.spark.sql("SELECT * FROM range(1)").toPandas()
+ self.assertTrue(
+ self.spark._jvm.OnSuccessCall.isCalled(),
+ "The callback from the query execution listener should be called after 'toPandas'")
+
+
class SparkSessionTests(PySparkTestCase):
# This test is separate because it's closely related with session's start and stop.
@@ -2980,18 +3540,22 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ cls.hive_available = True
try:
cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
except py4j.protocol.Py4JError:
- cls.tearDownClass()
- raise unittest.SkipTest("Hive is not available")
+ cls.hive_available = False
except TypeError:
- cls.tearDownClass()
- raise unittest.SkipTest("Hive is not available")
+ cls.hive_available = False
os.unlink(cls.tempdir.name)
- cls.spark = HiveContext._createForTesting(cls.sc)
- cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
- cls.df = cls.sc.parallelize(cls.testData).toDF()
+ if cls.hive_available:
+ cls.spark = HiveContext._createForTesting(cls.sc)
+ cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ cls.df = cls.sc.parallelize(cls.testData).toDF()
+
+ def setUp(self):
+ if not self.hive_available:
+ self.skipTest("Hive is not available.")
@classmethod
def tearDownClass(cls):
@@ -3383,7 +3947,9 @@ def __init__(self, **kwargs):
_make_type_verifier(data_type, nullable=False)(obj)
-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
class ArrowTests(ReusedSQLTestCase):
@classmethod
@@ -3400,6 +3966,8 @@ def setUpClass(cls):
cls.spark.conf.set("spark.sql.session.timeZone", tz)
cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+ # Disable fallback by default to easily detect the failures.
+ cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
cls.schema = StructType([
StructField("1_str_t", StringType(), True),
StructField("2_int_t", IntegerType(), True),
@@ -3435,11 +4003,28 @@ def create_pandas_data_frame(self):
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
return pd.DataFrame(data=data_dict)
- def test_unsupported_datatype(self):
+ def test_toPandas_fallback_enabled(self):
+ import pandas as pd
+
+ with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
+ schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
+ df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
+ with QuietTest(self.sc):
+ with warnings.catch_warnings(record=True) as warns:
+ pdf = df.toPandas()
+ # Catch and check the last UserWarning.
+ user_warns = [
+ warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+ self.assertTrue(len(user_warns) > 0)
+ self.assertTrue(
+ "Attempting non-optimization" in _exception_message(user_warns[-1]))
+ self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+
+ def test_toPandas_fallback_disabled(self):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
- with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
+ with self.assertRaisesRegexp(Exception, 'Unsupported type'):
df.toPandas()
def test_null_conversion(self):
@@ -3450,12 +4035,11 @@ def test_null_conversion(self):
self.assertTrue(all([c == 1 for c in null_counts]))
def _toPandas_arrow_toggle(self, df):
- self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
- try:
+ with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
pdf = df.toPandas()
- finally:
- self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+
pdf_arrow = df.toPandas()
+
return pdf, pdf_arrow
def test_toPandas_arrow_toggle(self):
@@ -3467,16 +4051,17 @@ def test_toPandas_arrow_toggle(self):
def test_toPandas_respect_session_timezone(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
- orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
- try:
- timezone = "America/New_York"
- self.spark.conf.set("spark.sql.session.timeZone", timezone)
- self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
- try:
- pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
- self.assertPandasEqual(pdf_arrow_la, pdf_la)
- finally:
- self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+
+ timezone = "America/New_York"
+ with self.sql_conf({
+ "spark.sql.execution.pandas.respectSessionTimeZone": False,
+ "spark.sql.session.timeZone": timezone}):
+ pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
+ self.assertPandasEqual(pdf_arrow_la, pdf_la)
+
+ with self.sql_conf({
+ "spark.sql.execution.pandas.respectSessionTimeZone": True,
+ "spark.sql.session.timeZone": timezone}):
pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
@@ -3489,8 +4074,6 @@ def test_toPandas_respect_session_timezone(self):
pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
pdf_la_corrected[field.name], timezone)
self.assertPandasEqual(pdf_ny, pdf_la_corrected)
- finally:
- self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
def test_pandas_round_trip(self):
pdf = self.create_pandas_data_frame()
@@ -3506,12 +4089,11 @@ def test_filtered_frame(self):
self.assertTrue(pdf.empty)
def _createDataFrame_toggle(self, pdf, schema=None):
- self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
- try:
+ with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
- finally:
- self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+
df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+
return df_no_arrow, df_arrow
def test_createDataFrame_toggle(self):
@@ -3522,18 +4104,18 @@ def test_createDataFrame_toggle(self):
def test_createDataFrame_respect_session_timezone(self):
from datetime import timedelta
pdf = self.create_pandas_data_frame()
- orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
- try:
- timezone = "America/New_York"
- self.spark.conf.set("spark.sql.session.timeZone", timezone)
- self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
- try:
- df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
- result_la = df_no_arrow_la.collect()
- result_arrow_la = df_arrow_la.collect()
- self.assertEqual(result_la, result_arrow_la)
- finally:
- self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+ timezone = "America/New_York"
+ with self.sql_conf({
+ "spark.sql.execution.pandas.respectSessionTimeZone": False,
+ "spark.sql.session.timeZone": timezone}):
+ df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
+ result_la = df_no_arrow_la.collect()
+ result_arrow_la = df_arrow_la.collect()
+ self.assertEqual(result_la, result_arrow_la)
+
+ with self.sql_conf({
+ "spark.sql.execution.pandas.respectSessionTimeZone": True,
+ "spark.sql.session.timeZone": timezone}):
df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
result_ny = df_no_arrow_ny.collect()
result_arrow_ny = df_arrow_ny.collect()
@@ -3546,8 +4128,6 @@ def test_createDataFrame_respect_session_timezone(self):
for k, v in row.asDict().items()})
for row in result_la]
self.assertEqual(result_ny, result_la_corrected)
- finally:
- self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
def test_createDataFrame_with_schema(self):
pdf = self.create_pandas_data_frame()
@@ -3560,7 +4140,7 @@ def test_createDataFrame_with_incorrect_schema(self):
pdf = self.create_pandas_data_frame()
wrong_schema = StructType(list(reversed(self.schema)))
with QuietTest(self.sc):
- with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"):
+ with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"):
self.spark.createDataFrame(pdf, schema=wrong_schema)
def test_createDataFrame_with_names(self):
@@ -3585,7 +4165,7 @@ def test_createDataFrame_column_name_encoding(self):
def test_createDataFrame_with_single_data_type(self):
import pandas as pd
with QuietTest(self.sc):
- with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"):
+ with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"):
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
def test_createDataFrame_does_not_modify_input(self):
@@ -3640,8 +4220,49 @@ def test_createDataFrame_with_int_col_names(self):
self.assertEqual(pdf_col_names, df.columns)
self.assertEqual(pdf_col_names, df_arrow.columns)
+ def test_createDataFrame_fallback_enabled(self):
+ import pandas as pd
-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+ with QuietTest(self.sc):
+ with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
+ with warnings.catch_warnings(record=True) as warns:
+ df = self.spark.createDataFrame(
+ pd.DataFrame([[{u'a': 1}]]), "a: map")
+ # Catch and check the last UserWarning.
+ user_warns = [
+ warn.message for warn in warns if isinstance(warn.message, UserWarning)]
+ self.assertTrue(len(user_warns) > 0)
+ self.assertTrue(
+ "Attempting non-optimization" in _exception_message(user_warns[-1]))
+ self.assertEqual(df.collect(), [Row(a={u'a': 1})])
+
+ def test_createDataFrame_fallback_disabled(self):
+ import pandas as pd
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
+ self.spark.createDataFrame(
+ pd.DataFrame([[{u'a': 1}]]), "a: map")
+
+ # Regression test for SPARK-23314
+ def test_timestamp_dst(self):
+ import pandas as pd
+ # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
+ dt = [datetime.datetime(2015, 11, 1, 0, 30),
+ datetime.datetime(2015, 11, 1, 1, 30),
+ datetime.datetime(2015, 11, 1, 2, 30)]
+ pdf = pd.DataFrame({'time': dt})
+
+ df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
+ df_from_pandas = self.spark.createDataFrame(pdf)
+
+ self.assertPandasEqual(pdf, df_from_python.toPandas())
+ self.assertPandasEqual(pdf, df_from_pandas.toPandas())
+
+
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
class PandasUDFTests(ReusedSQLTestCase):
def test_pandas_udf_basic(self):
from pyspark.rdd import PythonEvalType
@@ -3715,10 +4336,10 @@ def foo(x):
self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
- @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR)
+ @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR)
def foo(x):
return x
- self.assertEqual(foo.returnType, schema)
+ self.assertEqual(foo.returnType, DoubleType())
self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
@pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
@@ -3755,17 +4376,74 @@ def zero_with_type():
@pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
def foo(df):
return df
- with self.assertRaisesRegexp(ValueError, 'Invalid returnType'):
+ with self.assertRaisesRegexp(TypeError, 'Invalid returnType'):
@pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP)
def foo(df):
return df
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
@pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
- def foo(k, v):
+ def foo(k, v, w):
return k
+ def test_stopiteration_in_udf(self):
+ from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+ from py4j.protocol import Py4JJavaError
+
+ def foo(x):
+ raise StopIteration()
+
+ def foofoo(x, y):
+ raise StopIteration()
+
+ exc_message = "Caught StopIteration thrown from user's code; failing the task"
+ df = self.spark.range(0, 100)
+
+ # plain udf (test for SPARK-23754)
+ self.assertRaisesRegexp(
+ Py4JJavaError,
+ exc_message,
+ df.withColumn('v', udf(foo)('id')).collect
+ )
+
+ # pandas scalar udf
+ self.assertRaisesRegexp(
+ Py4JJavaError,
+ exc_message,
+ df.withColumn(
+ 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
+ ).collect
+ )
+
+ # pandas grouped map
+ self.assertRaisesRegexp(
+ Py4JJavaError,
+ exc_message,
+ df.groupBy('id').apply(
+ pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
+ ).collect
+ )
+
+ self.assertRaisesRegexp(
+ Py4JJavaError,
+ exc_message,
+ df.groupBy('id').apply(
+ pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP)
+ ).collect
+ )
+
+ # pandas grouped agg
+ self.assertRaisesRegexp(
+ Py4JJavaError,
+ exc_message,
+ df.groupBy('id').agg(
+ pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id')
+ ).collect
+ )
-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
class ScalarPandasUDFTests(ReusedSQLTestCase):
@classmethod
@@ -3802,7 +4480,7 @@ def random_udf(v):
return random_udf
def test_vectorized_udf_basic(self):
- from pyspark.sql.functions import pandas_udf, col
+ from pyspark.sql.functions import pandas_udf, col, array
df = self.spark.range(10).select(
col('id').cast('string').alias('str'),
col('id').cast('int').alias('int'),
@@ -3810,7 +4488,8 @@ def test_vectorized_udf_basic(self):
col('id').cast('float').alias('float'),
col('id').cast('double').alias('double'),
col('id').cast('decimal').alias('decimal'),
- col('id').cast('boolean').alias('bool'))
+ col('id').cast('boolean').alias('bool'),
+ array(col('id')).alias('array_long'))
f = lambda x: x
str_f = pandas_udf(f, StringType())
int_f = pandas_udf(f, IntegerType())
@@ -3819,10 +4498,11 @@ def test_vectorized_udf_basic(self):
double_f = pandas_udf(f, DoubleType())
decimal_f = pandas_udf(f, DecimalType())
bool_f = pandas_udf(f, BooleanType())
+ array_long_f = pandas_udf(f, ArrayType(LongType()))
res = df.select(str_f(col('str')), int_f(col('int')),
long_f(col('long')), float_f(col('float')),
double_f(col('double')), decimal_f('decimal'),
- bool_f(col('bool')))
+ bool_f(col('bool')), array_long_f('array_long'))
self.assertEquals(df.collect(), res.collect())
def test_register_nondeterministic_vectorized_udf_basic(self):
@@ -4027,10 +4707,11 @@ def test_vectorized_udf_chained(self):
def test_vectorized_udf_wrong_return_type(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10)
- f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
with QuietTest(self.sc):
- with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
- df.select(f(col('id'))).collect()
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*scalar Pandas UDF.*MapType'):
+ pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
def test_vectorized_udf_return_scalar(self):
from pyspark.sql.functions import pandas_udf, col
@@ -4065,13 +4746,18 @@ def test_vectorized_udf_varargs(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_unsupported_types(self):
- from pyspark.sql.functions import pandas_udf, col
- schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
- df = self.spark.createDataFrame([(None,)], schema=schema)
- f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
+ from pyspark.sql.functions import pandas_udf
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*scalar Pandas UDF.*MapType'):
+ pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
+
with QuietTest(self.sc):
- with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
- df.select(f(col('map'))).collect()
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*scalar Pandas UDF.*BinaryType'):
+ pandas_udf(lambda x: x, BinaryType())
def test_vectorized_udf_dates(self):
from pyspark.sql.functions import pandas_udf, col
@@ -4173,9 +4859,7 @@ def gen_timestamps(id):
def test_vectorized_udf_check_config(self):
from pyspark.sql.functions import pandas_udf, col
import pandas as pd
- orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
- self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
- try:
+ with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
df = self.spark.range(10, numPartitions=1)
@pandas_udf(returnType=LongType())
@@ -4185,11 +4869,6 @@ def check_records_per_batch(x):
result = df.select(check_records_per_batch(col("id"))).collect()
for (r,) in result:
self.assertTrue(r <= 3)
- finally:
- if orig_value is None:
- self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
- else:
- self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)
def test_vectorized_udf_timestamps_respect_session_timezone(self):
from pyspark.sql.functions import pandas_udf, col
@@ -4208,30 +4887,27 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
internal_value = pandas_udf(
lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())
- orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
- try:
- timezone = "America/New_York"
- self.spark.conf.set("spark.sql.session.timeZone", timezone)
- self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
- try:
- df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
- .withColumn("internal_value", internal_value(col("timestamp")))
- result_la = df_la.select(col("idx"), col("internal_value")).collect()
- # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
- diff = 3 * 60 * 60 * 1000 * 1000 * 1000
- result_la_corrected = \
- df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
- finally:
- self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+ timezone = "America/New_York"
+ with self.sql_conf({
+ "spark.sql.execution.pandas.respectSessionTimeZone": False,
+ "spark.sql.session.timeZone": timezone}):
+ df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+ .withColumn("internal_value", internal_value(col("timestamp")))
+ result_la = df_la.select(col("idx"), col("internal_value")).collect()
+ # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
+ diff = 3 * 60 * 60 * 1000 * 1000 * 1000
+ result_la_corrected = \
+ df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
+ with self.sql_conf({
+ "spark.sql.execution.pandas.respectSessionTimeZone": True,
+ "spark.sql.session.timeZone": timezone}):
df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()
self.assertNotEqual(result_ny, result_la)
self.assertEqual(result_ny, result_la_corrected)
- finally:
- self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
def test_nondeterministic_vectorized_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
@@ -4277,8 +4953,40 @@ def test_register_vectorized_udf_basic(self):
self.assertEquals(expected.collect(), res1.collect())
self.assertEquals(expected.collect(), res2.collect())
+ # Regression test for SPARK-23314
+ def test_timestamp_dst(self):
+ from pyspark.sql.functions import pandas_udf
+ # Daylight saving time in Los Angeles for 2015 ends on Sun, Nov 1 at 2:00 am
+ dt = [datetime.datetime(2015, 11, 1, 0, 30),
+ datetime.datetime(2015, 11, 1, 1, 30),
+ datetime.datetime(2015, 11, 1, 2, 30)]
+ df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
+ foo_udf = pandas_udf(lambda x: x, 'timestamp')
+ result = df.withColumn('time', foo_udf(df.time))
+ self.assertEquals(df.collect(), result.collect())
-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+ @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
+ def test_type_annotation(self):
+ from pyspark.sql.functions import pandas_udf
+ # Regression test to check if type hints can be used. See SPARK-23569.
+ # Note that it throws an error during compilation in lower Python versions if 'exec'
+ # is not used. Also, note that we explicitly use another dictionary to avoid modifications
+ # in the current 'locals()'.
+ #
+ # Hyukjin: I think it's an ugly way to test issues with syntax that is specific to
+ # higher versions of Python, which we shouldn't encourage. This was the last resort
+ # I could come up with at that time.
+ _locals = {}
+ exec(
+ "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col",
+ _locals)
+ df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
+ self.assertEqual(df.first()[0], 0)
+
+
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
class GroupedMapPandasUDFTests(ReusedSQLTestCase):
@property
@@ -4288,22 +4996,68 @@ def data(self):
.withColumn("vs", array([lit(i) for i in range(20, 30)])) \
.withColumn("v", explode(col('vs'))).drop('vs')
- def test_simple(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
- df = self.data
+ def test_supported_types(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
+ df = self.data.withColumn("arr", array(col("id")))
- foo_udf = pandas_udf(
+ # Different forms of grouped map pandas UDF; the results of these are the same
+
+ output_schema = StructType(
+ [StructField('id', LongType()),
+ StructField('v', IntegerType()),
+ StructField('arr', ArrayType(LongType())),
+ StructField('v1', DoubleType()),
+ StructField('v2', LongType())])
+
+ udf1 = pandas_udf(
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
- StructType(
- [StructField('id', LongType()),
- StructField('v', IntegerType()),
- StructField('v1', DoubleType()),
- StructField('v2', LongType())]),
+ output_schema,
PandasUDFType.GROUPED_MAP
)
- result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
- expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
+ udf2 = pandas_udf(
+ lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ udf3 = pandas_udf(
+ lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
+ expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
+
+ result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
+ expected2 = expected1
+
+ result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
+ expected3 = expected1
+
+ self.assertPandasEqual(expected1, result1)
+ self.assertPandasEqual(expected2, result2)
+ self.assertPandasEqual(expected3, result3)
+
+ def test_array_type_correct(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
+
+ df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")
+
+ output_schema = StructType(
+ [StructField('id', LongType()),
+ StructField('v', IntegerType()),
+ StructField('arr', ArrayType(LongType()))])
+
+ udf = pandas_udf(
+ lambda pdf: pdf,
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ result = df.groupby('id').apply(udf).sort('id').toPandas()
+ expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True)
self.assertPandasEqual(expected, result)
def test_register_grouped_map_udf(self):
@@ -4399,17 +5153,15 @@ def test_datatype_string(self):
def test_wrong_return_type(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
- df = self.data
-
- foo = pandas_udf(
- lambda pdf: pdf,
- 'id long, v map<int, int>',
- PandasUDFType.GROUPED_MAP
- )
with QuietTest(self.sc):
- with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
- df.groupby('id').apply(foo).sort('id').toPandas()
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*grouped map Pandas UDF.*MapType'):
+ pandas_udf(
+ lambda pdf: pdf,
+ 'id long, v map<int, int>',
+ PandasUDFType.GROUPED_MAP)
def test_wrong_args(self):
from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType
@@ -4428,26 +5180,121 @@ def test_wrong_args(self):
df.groupby('id').apply(
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
- df.groupby('id').apply(
- pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())])))
+ df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType()))
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
df.groupby('id').apply(
- pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]),
- PandasUDFType.SCALAR))
+ pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
def test_unsupported_types(self):
- from pyspark.sql.functions import pandas_udf, col, PandasUDFType
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
schema = StructType(
[StructField("id", LongType(), True),
StructField("map", MapType(StringType(), IntegerType()), True)])
- df = self.spark.createDataFrame([(1, None,)], schema=schema)
- f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP)
with QuietTest(self.sc):
- with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
- df.groupby('id').apply(f).collect()
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*grouped map Pandas UDF.*MapType'):
+ pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
+
+ schema = StructType(
+ [StructField("id", LongType(), True),
+ StructField("arr_ts", ArrayType(TimestampType()), True)])
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
+ pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
+
+ # Regression test for SPARK-23314
+ def test_timestamp_dst(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+ # Daylight saving time in Los Angeles for 2015 ends on Sun, Nov 1 at 2:00 am
+ dt = [datetime.datetime(2015, 11, 1, 0, 30),
+ datetime.datetime(2015, 11, 1, 1, 30),
+ datetime.datetime(2015, 11, 1, 2, 30)]
+ df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
+ foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP)
+ result = df.groupby('time').apply(foo_udf).sort('time')
+ self.assertPandasEqual(df.toPandas(), result.toPandas())
+
+ def test_udf_with_key(self):
+ from pyspark.sql.functions import pandas_udf, col, PandasUDFType
+ df = self.data
+ pdf = df.toPandas()
+
+ def foo1(key, pdf):
+ import numpy as np
+ assert type(key) == tuple
+ assert type(key[0]) == np.int64
+
+ return pdf.assign(v1=key[0],
+ v2=pdf.v * key[0],
+ v3=pdf.v * pdf.id,
+ v4=pdf.v * pdf.id.mean())
+
+ def foo2(key, pdf):
+ import numpy as np
+ assert type(key) == tuple
+ assert type(key[0]) == np.int64
+ assert type(key[1]) == np.int32
+
+ return pdf.assign(v1=key[0],
+ v2=key[1],
+ v3=pdf.v * key[0],
+ v4=pdf.v + key[1])
+
+ def foo3(key, pdf):
+ assert type(key) == tuple
+ assert len(key) == 0
+ return pdf.assign(v1=pdf.v * pdf.id)
+
+ # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
+ # v3 is long because pd.Series<int32> * pd.Series<int64> results in pd.Series<int64>
+ udf1 = pandas_udf(
+ foo1,
+ 'id long, v int, v1 long, v2 int, v3 long, v4 double',
+ PandasUDFType.GROUPED_MAP)
+
+ udf2 = pandas_udf(
+ foo2,
+ 'id long, v int, v1 long, v2 int, v3 int, v4 int',
+ PandasUDFType.GROUPED_MAP)
+
+ udf3 = pandas_udf(
+ foo3,
+ 'id long, v int, v1 long',
+ PandasUDFType.GROUPED_MAP)
+
+ # Test groupby column
+ result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
+ expected1 = pdf.groupby('id')\
+ .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected1, result1)
+
+ # Test groupby expression
+ result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
+ expected2 = pdf.groupby(pdf.id % 2)\
+ .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected2, result2)
+
+ # Test complex groupby
+ result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
+ expected3 = pdf.groupby([pdf.id, pdf.v % 2])\
+ .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected3, result3)
+
+ # Test empty groupby
+ result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
+ expected4 = udf3.func((), pdf)
+ self.assertPandasEqual(expected4, result4)
-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
class GroupedAggPandasUDFTests(ReusedSQLTestCase):
@property
@@ -4509,23 +5356,32 @@ def weighted_mean(v, w):
return weighted_mean
def test_manual(self):
+ from pyspark.sql.functions import pandas_udf, array
+
df = self.data
sum_udf = self.pandas_agg_sum_udf
mean_udf = self.pandas_agg_mean_udf
-
- result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id')
+ mean_arr_udf = pandas_udf(
+ self.pandas_agg_mean_udf.func,
+ ArrayType(self.pandas_agg_mean_udf.returnType),
+ self.pandas_agg_mean_udf.evalType)
+
+ result1 = df.groupby('id').agg(
+ sum_udf(df.v),
+ mean_udf(df.v),
+ mean_arr_udf(array(df.v))).sort('id')
expected1 = self.spark.createDataFrame(
- [[0, 245.0, 24.5],
- [1, 255.0, 25.5],
- [2, 265.0, 26.5],
- [3, 275.0, 27.5],
- [4, 285.0, 28.5],
- [5, 295.0, 29.5],
- [6, 305.0, 30.5],
- [7, 315.0, 31.5],
- [8, 325.0, 32.5],
- [9, 335.0, 33.5]],
- ['id', 'sum(v)', 'avg(v)'])
+ [[0, 245.0, 24.5, [24.5]],
+ [1, 255.0, 25.5, [25.5]],
+ [2, 265.0, 26.5, [26.5]],
+ [3, 275.0, 27.5, [27.5]],
+ [4, 285.0, 28.5, [28.5]],
+ [5, 295.0, 29.5, [29.5]],
+ [6, 305.0, 30.5, [30.5]],
+ [7, 315.0, 31.5, [31.5]],
+ [8, 325.0, 32.5, [32.5]],
+ [9, 335.0, 33.5, [33.5]]],
+ ['id', 'sum(v)', 'avg(v)', 'avg(array(v))'])
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
@@ -4562,14 +5418,15 @@ def test_basic(self):
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
def test_unsupported_types(self):
- from pyspark.sql.types import ArrayType, DoubleType, MapType
+ from pyspark.sql.types import DoubleType, MapType
from pyspark.sql.functions import pandas_udf, PandasUDFType
with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
- @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG)
- def mean_and_std_udf(v):
- return [v.mean(), v.std()]
+ pandas_udf(
+ lambda x: x,
+ ArrayType(ArrayType(TimestampType())),
+ PandasUDFType.GROUPED_AGG)
with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
@@ -4742,8 +5599,8 @@ def test_complex_groupby(self):
expected2 = df.groupby().agg(sum(df.v))
# groupby one column and one sql expression
- result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v))
- expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v))
+ result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2)
+ expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2)
# groupby one python UDF
result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v))
@@ -4846,9 +5703,7 @@ def test_complex_expressions(self):
def test_retain_group_columns(self):
from pyspark.sql.functions import sum, lit, col
- orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None)
- self.spark.conf.set("spark.sql.retainGroupColumns", False)
- try:
+ with self.sql_conf({"spark.sql.retainGroupColumns": False}):
df = self.data
sum_udf = self.pandas_agg_sum_udf
@@ -4856,11 +5711,14 @@ def test_retain_group_columns(self):
expected1 = df.groupby(df.id).agg(sum(df.v))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
- finally:
- if orig_value is None:
- self.spark.conf.unset("spark.sql.retainGroupColumns")
- else:
- self.spark.conf.set("spark.sql.retainGroupColumns", orig_value)
+ def test_array_type(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ df = self.data
+
+ array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
+ result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
+ self.assertEquals(result1.first()['v2'], [1.0, 2.0])
def test_invalid_args(self):
from pyspark.sql.functions import mean
@@ -4887,9 +5745,238 @@ def test_invalid_args(self):
'mixture.*aggregate function.*group aggregate pandas UDF'):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
+
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
+class WindowPandasUDFTests(ReusedSQLTestCase):
+ @property
+ def data(self):
+ from pyspark.sql.functions import array, explode, col, lit
+ return self.spark.range(10).toDF('id') \
+ .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
+ .withColumn("v", explode(col('vs'))) \
+ .drop('vs') \
+ .withColumn('w', lit(1.0))
+
+ @property
+ def python_plus_one(self):
+ from pyspark.sql.functions import udf
+ return udf(lambda v: v + 1, 'double')
+
+ @property
+ def pandas_scalar_time_two(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+ return pandas_udf(lambda v: v * 2, 'double')
+
+ @property
+ def pandas_agg_mean_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def avg(v):
+ return v.mean()
+ return avg
+
+ @property
+ def pandas_agg_max_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def max(v):
+ return v.max()
+ return max
+
+ @property
+ def pandas_agg_min_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def min(v):
+ return v.min()
+ return min
+
+ @property
+ def unbounded_window(self):
+ return Window.partitionBy('id') \
+ .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+
+ @property
+ def ordered_window(self):
+ return Window.partitionBy('id').orderBy('v')
+
+ @property
+ def unpartitioned_window(self):
+ return Window.partitionBy()
+
+ def test_simple(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max
+
+ df = self.data
+ w = self.unbounded_window
+
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w))
+ expected1 = df.withColumn('mean_v', mean(df['v']).over(w))
+
+ result2 = df.select(mean_udf(df['v']).over(w))
+ expected2 = df.select(mean(df['v']).over(w))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+ def test_multiple_udfs(self):
+ from pyspark.sql.functions import max, min, mean
+
+ df = self.data
+ w = self.unbounded_window
+
+ result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
+ .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
+ .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))
+
+ expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
+ .withColumn('max_v', max(df['v']).over(w)) \
+ .withColumn('min_w', min(df['w']).over(w))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_replace_existing(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w = self.unbounded_window
+
+ result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w))
+ expected1 = df.withColumn('v', mean(df['v']).over(w))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_mixed_sql(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w = self.unbounded_window
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1)
+ expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1)
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_mixed_udf(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w = self.unbounded_window
+
+ plus_one = self.python_plus_one
+ time_two = self.pandas_scalar_time_two
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn(
+ 'v2',
+ plus_one(mean_udf(plus_one(df['v'])).over(w)))
+ expected1 = df.withColumn(
+ 'v2',
+ plus_one(mean(plus_one(df['v'])).over(w)))
+
+ result2 = df.withColumn(
+ 'v2',
+ time_two(mean_udf(time_two(df['v'])).over(w)))
+ expected2 = df.withColumn(
+ 'v2',
+ time_two(mean(time_two(df['v'])).over(w)))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+ def test_without_partitionBy(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w = self.unpartitioned_window
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn('v2', mean_udf(df['v']).over(w))
+ expected1 = df.withColumn('v2', mean(df['v']).over(w))
+
+ result2 = df.select(mean_udf(df['v']).over(w))
+ expected2 = df.select(mean(df['v']).over(w))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+ def test_mixed_sql_and_udf(self):
+ from pyspark.sql.functions import max, min, rank, col
+
+ df = self.data
+ w = self.unbounded_window
+ ow = self.ordered_window
+ max_udf = self.pandas_agg_max_udf
+ min_udf = self.pandas_agg_min_udf
+
+ result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w))
+ expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w))
+
+ # Test mixing sql window function and window udf in the same expression
+ result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w))
+ expected2 = expected1
+
+ # Test chaining sql aggregate function and udf
+ result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
+ .withColumn('min_v', min(df['v']).over(w)) \
+ .withColumn('v_diff', col('max_v') - col('min_v')) \
+ .drop('max_v', 'min_v')
+ expected3 = expected1
+
+ # Test mixing sql window function and udf
+ result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
+ .withColumn('rank', rank().over(ow))
+ expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
+ .withColumn('rank', rank().over(ow))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+ self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+ self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+
+ def test_array_type(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ df = self.data
+ w = self.unbounded_window
+
+ array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
+ result1 = df.withColumn('v2', array_udf(df['v']).over(w))
+ self.assertEquals(result1.first()['v2'], [1.0, 2.0])
+
+ def test_invalid_args(self):
+ from pyspark.sql.functions import mean, pandas_udf, PandasUDFType
+
+ df = self.data
+ w = self.unbounded_window
+ ow = self.ordered_window
+ mean_udf = self.pandas_agg_mean_udf
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ AnalysisException,
+ '.*not supported within a window function'):
+ foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
+ df.withColumn('v2', foo_udf(df['v']).over(w))
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ AnalysisException,
+ '.*Only unbounded window frame is supported.*'):
+ df.withColumn('mean_v', mean_udf(df['v']).over(ow))
+
+
if __name__ == "__main__":
from pyspark.sql.tests import *
if xmlrunner:
- unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
else:
- unittest.main()
+ unittest.main(verbosity=2)
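
Several hunks above replace manual save/set/restore of SQL confs (try/finally around spark.conf.set) with a `self.sql_conf({...})` context manager. Below is a minimal sketch of what such a helper can look like, assuming an active `spark` session; the actual helper lives on the shared test base class and may differ in detail:

    from contextlib import contextmanager

    @contextmanager
    def sql_conf(spark, pairs):
        # Remember the current values, apply the overrides, and restore (or unset) them on exit.
        keys = list(pairs.keys())
        old_values = [spark.conf.get(key, None) for key in keys]
        for key, value in pairs.items():
            spark.conf.set(key, value)
        try:
            yield
        finally:
            for key, old in zip(keys, old_values):
                if old is None:
                    spark.conf.unset(key)
                else:
                    spark.conf.set(key, old)
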
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 093dae5a22e1f..3cd7a2ef115af 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -35,7 +35,6 @@
from pyspark import SparkContext
from pyspark.serializers import CloudPickleSerializer
-from pyspark.util import _exception_message
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
@@ -290,7 +289,8 @@ def __init__(self, elementType, containsNull=True):
>>> ArrayType(StringType(), False) == ArrayType(StringType())
False
"""
- assert isinstance(elementType, DataType), "elementType should be DataType"
+ assert isinstance(elementType, DataType),\
+ "elementType %s should be an instance of %s" % (elementType, DataType)
self.elementType = elementType
self.containsNull = containsNull
@@ -344,8 +344,10 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
... == MapType(StringType(), FloatType()))
False
"""
- assert isinstance(keyType, DataType), "keyType should be DataType"
- assert isinstance(valueType, DataType), "valueType should be DataType"
+ assert isinstance(keyType, DataType),\
+ "keyType %s should be an instance of %s" % (keyType, DataType)
+ assert isinstance(valueType, DataType),\
+ "valueType %s should be an instance of %s" % (valueType, DataType)
self.keyType = keyType
self.valueType = valueType
self.valueContainsNull = valueContainsNull
@@ -403,8 +405,9 @@ def __init__(self, name, dataType, nullable=True, metadata=None):
... == StructField("f2", StringType(), True))
False
"""
- assert isinstance(dataType, DataType), "dataType should be DataType"
- assert isinstance(name, basestring), "field name should be string"
+ assert isinstance(dataType, DataType),\
+ "dataType %s should be an instance of %s" % (dataType, DataType)
+ assert isinstance(name, basestring), "field name %s should be string" % (name)
if not isinstance(name, str):
name = name.encode('utf-8')
self.name = name
@@ -455,9 +458,6 @@ class StructType(DataType):
Iterating a :class:`StructType` will iterate its :class:`StructField`\\s.
A contained :class:`StructField` can be accessed by name or position.
- .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead
- to get a list of field names.
-
>>> struct1 = StructType([StructField("f1", StringType(), True)])
>>> struct1["f1"]
StructField(f1,StringType,true)
@@ -755,41 +755,6 @@ def __eq__(self, other):
_FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)")
-_BRACKETS = {'(': ')', '[': ']', '{': '}'}
-
-
-def _ignore_brackets_split(s, separator):
- """
- Splits the given string by given separator, but ignore separators inside brackets pairs, e.g.
- given "a,b" and separator ",", it will return ["a", "b"], but given "a, d", it will return
- ["a", "d"].
- """
- parts = []
- buf = ""
- level = 0
- for c in s:
- if c in _BRACKETS.keys():
- level += 1
- buf += c
- elif c in _BRACKETS.values():
- if level == 0:
- raise ValueError("Brackets are not correctly paired: %s" % s)
- level -= 1
- buf += c
- elif c == separator and level > 0:
- buf += c
- elif c == separator:
- parts.append(buf)
- buf = ""
- else:
- buf += c
-
- if len(buf) == 0:
- raise ValueError("The %s cannot be the last char: %s" % (separator, s))
- parts.append(buf)
- return parts
-
-
def _parse_datatype_string(s):
"""
Parses the given data type string to a :class:`DataType`. The data type string format equals
@@ -1638,6 +1603,8 @@ def to_arrow_type(dt):
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
arrow_type = pa.timestamp('us', tz='UTC')
elif type(dt) == ArrayType:
+ if type(dt.elementType) == TimestampType:
+ raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.list_(to_arrow_type(dt.elementType))
else:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
@@ -1680,6 +1647,8 @@ def from_arrow_type(at):
elif types.is_timestamp(at):
spark_type = TimestampType()
elif types.is_list(at):
+ if types.is_timestamp(at.value_type):
+ raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = ArrayType(from_arrow_type(at.value_type))
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
@@ -1694,6 +1663,19 @@ def from_arrow_schema(arrow_schema):
for field in arrow_schema])
+def _check_series_convert_date(series, data_type):
+ """
+ Cast the series to datetime.date if it's a date type, otherwise returns the original series.
+
+ :param series: pandas.Series
+ :param data_type: a Spark data type for the series
+ """
+ if type(data_type) == DateType:
+ return series.dt.date
+ else:
+ return series
+
+
def _check_dataframe_convert_date(pdf, schema):
""" Correct date type value to use datetime.date.
@@ -1704,11 +1686,48 @@ def _check_dataframe_convert_date(pdf, schema):
:param schema: a Spark schema of the pandas.DataFrame
"""
for field in schema:
- if type(field.dataType) == DateType:
- pdf[field.name] = pdf[field.name].dt.date
+ pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
return pdf
+def _get_local_timezone():
+ """ Get local timezone using pytz with environment variable, or dateutil.
+
+ If the 'TZ' environment variable is set, pass it to pandas so that pytz is used with it as
+ the timezone string; otherwise use the special word 'dateutil/:', which makes pandas use
+ dateutil and read the system configuration to determine the system local timezone.
+
+ See also:
+ - https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753
+ - https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338
+ """
+ import os
+ return os.environ.get('TZ', 'dateutil/:')
+
+
+def _check_series_localize_timestamps(s, timezone):
+ """
+ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
+
+ If the input series is not a timestamp series, then the same series is returned. If the input
+ series is a timestamp series, then a converted series is returned.
+
+ :param s: pandas.Series
+ :param timezone: the timezone to convert to. If None, use the local timezone
+ :return pandas.Series that has been converted to tz-naive
+ """
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+
+ from pandas.api.types import is_datetime64tz_dtype
+ tz = timezone or _get_local_timezone()
+ # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+ if is_datetime64tz_dtype(s.dtype):
+ return s.dt.tz_convert(tz).dt.tz_localize(None)
+ else:
+ return s
+
+
def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
@@ -1720,12 +1739,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
- from pandas.api.types import is_datetime64tz_dtype
- tz = timezone or 'tzlocal()'
for column, series in pdf.iteritems():
- # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
- if is_datetime64tz_dtype(series.dtype):
- pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
+ pdf[column] = _check_series_localize_timestamps(series, timezone)
return pdf
@@ -1744,8 +1759,38 @@ def _check_series_convert_timestamps_internal(s, timezone):
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64_dtype(s.dtype):
- tz = timezone or 'tzlocal()'
- return s.dt.tz_localize(tz).dt.tz_convert('UTC')
+ # When localizing a tz-naive timestamp with tz_localize, the result is ambiguous if the
+ # tz-naive timestamp falls in the hour when the clock is adjusted backward due to
+ # daylight saving time (dst).
+ # E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to
+ # 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize
+ # a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either
+ # dst time (2015-11-01 1:30-0400) or standard time (2015-11-01 1:30-0500).
+ #
+ # Here we explicitly choose to use standard time. This matches the default behavior of
+ # pytz.
+ #
+ # Here is some code to help understand this behavior:
+ # >>> import datetime
+ # >>> import pandas as pd
+ # >>> import pytz
+ # >>>
+ # >>> t = datetime.datetime(2015, 11, 1, 1, 30)
+ # >>> ts = pd.Series([t])
+ # >>> tz = pytz.timezone('America/New_York')
+ # >>>
+ # >>> ts.dt.tz_localize(tz, ambiguous=True)
+ # 0 2015-11-01 01:30:00-04:00
+ # dtype: datetime64[ns, America/New_York]
+ # >>>
+ # >>> ts.dt.tz_localize(tz, ambiguous=False)
+ # 0 2015-11-01 01:30:00-05:00
+ # dtype: datetime64[ns, America/New_York]
+ # >>>
+ # >>> str(tz.localize(t))
+ # '2015-11-01 01:30:00-05:00'
+ tz = timezone or _get_local_timezone()
+ return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC')
elif is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert('UTC')
else:
@@ -1766,15 +1811,16 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
import pandas as pd
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
- from_tz = from_timezone or 'tzlocal()'
- to_tz = to_timezone or 'tzlocal()'
+ from_tz = from_timezone or _get_local_timezone()
+ to_tz = to_timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(to_tz).dt.tz_localize(None)
elif is_datetime64_dtype(s.dtype) and from_tz != to_tz:
# `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
- return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None)
- if ts is not pd.NaT else pd.NaT)
+ return s.apply(
+ lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None)
+ if ts is not pd.NaT else pd.NaT)
else:
return s
@@ -1812,7 +1858,7 @@ def _test():
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 0f759c448b8a7..9dbe49b831cef 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -18,12 +18,14 @@
User-defined function related classes and functions
"""
import functools
+import sys
from pyspark import SparkContext, since
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
from pyspark.sql.column import Column, _to_java_column, _to_seq
-from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \
- _parse_datatype_string
+from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
+ to_arrow_type, to_arrow_schema
+from pyspark.util import _get_argspec
__all__ = ["UDFRegistration"]
@@ -41,11 +43,10 @@ def _create_udf(f, returnType, evalType):
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
- import inspect
from pyspark.sql.utils import require_minimum_pyarrow_version
-
require_minimum_pyarrow_version()
- argspec = inspect.getargspec(f)
+
+ argspec = _get_argspec(f)
if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
argspec.varargs is None:
@@ -54,11 +55,11 @@ def _create_udf(f, returnType, evalType):
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
)
- if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1:
+ if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
+ and len(argspec.args) not in (1, 2):
raise ValueError(
"Invalid function: pandas_udfs with function type GROUPED_MAP "
- "must take a single arg that is a pandas DataFrame."
- )
+ "must take either one argument (data) or two arguments (key, data).")
# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(
@@ -112,15 +113,31 @@ def returnType(self):
else:
self._returnType_placeholder = _parse_datatype_string(self._returnType)
- if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
- and not isinstance(self._returnType_placeholder, StructType):
- raise ValueError("Invalid returnType: returnType must be a StructType for "
- "pandas_udf with function type GROUPED_MAP")
- elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \
- and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)):
- raise NotImplementedError(
- "ArrayType, StructType and MapType are not supported with "
- "PandasUDFType.GROUPED_AGG")
+ if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
+ try:
+ to_arrow_type(self._returnType_placeholder)
+ except TypeError:
+ raise NotImplementedError(
+ "Invalid returnType with scalar Pandas UDFs: %s is "
+ "not supported" % str(self._returnType_placeholder))
+ elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+ if isinstance(self._returnType_placeholder, StructType):
+ try:
+ to_arrow_schema(self._returnType_placeholder)
+ except TypeError:
+ raise NotImplementedError(
+ "Invalid returnType with grouped map Pandas UDFs: "
+ "%s is not supported" % str(self._returnType_placeholder))
+ else:
+ raise TypeError("Invalid returnType for grouped map Pandas "
+ "UDFs: returnType must be a StructType.")
+ elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
+ try:
+ to_arrow_type(self._returnType_placeholder)
+ except TypeError:
+ raise NotImplementedError(
+ "Invalid returnType with grouped aggregate Pandas UDFs: "
+ "%s is not supported" % str(self._returnType_placeholder))
return self._returnType_placeholder
@@ -356,7 +373,7 @@ def registerJavaUDAF(self, name, javaClassName):
>>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
>>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
- >>> df.registerTempTable("df")
+ >>> df.createOrReplaceTempView("df")
>>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
[Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
"""
@@ -379,7 +396,7 @@ def _test():
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
spark.stop()
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 08c34c6dccc5e..45363f089a73d 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -115,18 +115,38 @@ def toJArray(gateway, jtype, arr):
def require_minimum_pandas_version():
""" Raise ImportError if minimum version of Pandas is not installed
"""
+ # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
+ minimum_pandas_version = "0.19.2"
+
from distutils.version import LooseVersion
- import pandas
- if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
- raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; "
- "however, your version was %s." % pandas.__version__)
+ try:
+ import pandas
+ have_pandas = True
+ except ImportError:
+ have_pandas = False
+ if not have_pandas:
+ raise ImportError("Pandas >= %s must be installed; however, "
+ "it was not found." % minimum_pandas_version)
+ if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
+ raise ImportError("Pandas >= %s must be installed; however, "
+ "your version was %s." % (minimum_pandas_version, pandas.__version__))
def require_minimum_pyarrow_version():
""" Raise ImportError if minimum version of pyarrow is not installed
"""
+ # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
+ minimum_pyarrow_version = "0.8.0"
+
from distutils.version import LooseVersion
- import pyarrow
- if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'):
- raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; "
- "however, your version was %s." % pyarrow.__version__)
+ try:
+ import pyarrow
+ have_arrow = True
+ except ImportError:
+ have_arrow = False
+ if not have_arrow:
+ raise ImportError("PyArrow >= %s must be installed; however, "
+ "it was not found." % minimum_pyarrow_version)
+ if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
+ raise ImportError("PyArrow >= %s must be installed; however, "
+ "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index 7ce27f9b102c0..d19ced954f04e 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -16,9 +16,11 @@
#
import sys
+if sys.version >= '3':
+ long = int
from pyspark import since, SparkContext
-from pyspark.sql.column import _to_seq, _to_java_column
+from pyspark.sql.column import Column, _to_seq, _to_java_column
__all__ = ["Window", "WindowSpec"]
@@ -42,6 +44,10 @@ class Window(object):
>>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING
>>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3)
+ .. note:: When ordering is not defined, an unbounded window frame (rowFrame,
+ unboundedPreceding, unboundedFollowing) is used by default. When ordering is defined,
+ a growing window frame (rangeFrame, unboundedPreceding, currentRow) is used by default.
+
.. note:: Experimental
.. versionadded:: 1.4
@@ -120,20 +126,45 @@ def rangeBetween(start, end):
and "5" means the five off after the current row.
We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``,
- and ``Window.currentRow`` to specify special boundary values, rather than using integral
- values directly.
+ ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``,
+ ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow``
+ to specify special boundary values, rather than using integral values directly.
:param start: boundary start, inclusive.
- The frame is unbounded if this is ``Window.unboundedPreceding``, or
+ The frame is unbounded if this is ``Window.unboundedPreceding``,
+ a column returned by ``pyspark.sql.functions.unboundedPreceding``, or
any value less than or equal to max(-sys.maxsize, -9223372036854775808).
:param end: boundary end, inclusive.
- The frame is unbounded if this is ``Window.unboundedFollowing``, or
+ The frame is unbounded if this is ``Window.unboundedFollowing``,
+ a column returned by ``pyspark.sql.functions.unboundedFollowing``, or
any value greater than or equal to min(sys.maxsize, 9223372036854775807).
+
+ >>> from pyspark.sql import functions as F, SparkSession, Window
+ >>> spark = SparkSession.builder.getOrCreate()
+ >>> df = spark.createDataFrame(
+ ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"])
+ >>> window = Window.orderBy("id").partitionBy("category").rangeBetween(
+ ... F.currentRow(), F.lit(1))
+ >>> df.withColumn("sum", F.sum("id").over(window)).show()
+ +---+--------+---+
+ | id|category|sum|
+ +---+--------+---+
+ | 1| b| 3|
+ | 2| b| 5|
+ | 3| b| 3|
+ | 1| a| 4|
+ | 1| a| 4|
+ | 2| a| 2|
+ +---+--------+---+
"""
- if start <= Window._PRECEDING_THRESHOLD:
- start = Window.unboundedPreceding
- if end >= Window._FOLLOWING_THRESHOLD:
- end = Window.unboundedFollowing
+ if isinstance(start, (int, long)) and isinstance(end, (int, long)):
+ if start <= Window._PRECEDING_THRESHOLD:
+ start = Window.unboundedPreceding
+ if end >= Window._FOLLOWING_THRESHOLD:
+ end = Window.unboundedFollowing
+ elif isinstance(start, Column) and isinstance(end, Column):
+ start = start._jc
+ end = end._jc
sc = SparkContext._active_spark_context
jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end)
return WindowSpec(jspec)
@@ -208,29 +239,36 @@ def rangeBetween(self, start, end):
and "5" means the five off after the current row.
We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``,
- and ``Window.currentRow`` to specify special boundary values, rather than using integral
- values directly.
+ ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``,
+ ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow``
+ to specify special boundary values, rather than using integral values directly.
:param start: boundary start, inclusive.
- The frame is unbounded if this is ``Window.unboundedPreceding``, or
+ The frame is unbounded if this is ``Window.unboundedPreceding``,
+ a column returned by ``pyspark.sql.functions.unboundedPreceding``, or
any value less than or equal to max(-sys.maxsize, -9223372036854775808).
:param end: boundary end, inclusive.
- The frame is unbounded if this is ``Window.unboundedFollowing``, or
+ The frame is unbounded if this is ``Window.unboundedFollowing``,
+ a column returned by ``pyspark.sql.functions.unboundedFollowing``, or
any value greater than or equal to min(sys.maxsize, 9223372036854775807).
"""
- if start <= Window._PRECEDING_THRESHOLD:
- start = Window.unboundedPreceding
- if end >= Window._FOLLOWING_THRESHOLD:
- end = Window.unboundedFollowing
+ if isinstance(start, (int, long)) and isinstance(end, (int, long)):
+ if start <= Window._PRECEDING_THRESHOLD:
+ start = Window.unboundedPreceding
+ if end >= Window._FOLLOWING_THRESHOLD:
+ end = Window.unboundedFollowing
+ elif isinstance(start, Column) and isinstance(end, Column):
+ start = start._jc
+ end = end._jc
return WindowSpec(self._jspec.rangeBetween(start, end))
def _test():
import doctest
SparkContext('local[4]', 'PythonTest')
- (failure_count, test_count) = doctest.testmod()
+ (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
if failure_count:
- exit(-1)
+ sys.exit(-1)
if __name__ == "__main__":
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 17c34f8a1c54c..dd924ef89868e 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -338,7 +338,7 @@ def transform(self, dstreams, transformFunc):
jdstreams = [d._jdstream for d in dstreams]
# change the final serializer to sc.serializer
func = TransformFunction(self._sc,
- lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
+ lambda t, *rdds: transformFunc(rdds),
*[d._jrdd_deserializer for d in dstreams])
jfunc = self._jvm.TransformFunction(func)
jdstream = self._jssc.transform(jdstreams, jfunc)
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index fdb9308604489..ed2e0e7d10fa2 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -104,7 +104,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None,
:param topics: list of topic_name to consume.
:param kafkaParams: Additional params for Kafka.
:param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting
- point of the stream.
+ point of the stream (a dictionary mapping `TopicAndPartition` to
+ integers).
:param keyDecoder: A function used to decode key (default is utf8_decoder).
:param valueDecoder: A function used to decode value (default is utf8_decoder).
:param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess
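
The clarified `fromOffsets` documentation above describes a dictionary keyed by `TopicAndPartition`. A hedged usage sketch; the topic name, partition and broker address are placeholders, and `ssc` is assumed to be an existing StreamingContext:

    from pyspark.streaming.kafka import KafkaUtils, TopicAndPartition

    # Start reading topic "events", partition 0, from offset 0 (inclusive).
    from_offsets = {TopicAndPartition("events", 0): 0}
    stream = KafkaUtils.createDirectStream(
        ssc, ["events"], {"metadata.broker.list": "localhost:9092"},
        fromOffsets=from_offsets)
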
diff --git a/python/pyspark/streaming/listener.py b/python/pyspark/streaming/listener.py
index b830797f5c0a0..d4ecc215aea99 100644
--- a/python/pyspark/streaming/listener.py
+++ b/python/pyspark/streaming/listener.py
@@ -23,6 +23,12 @@ class StreamingListener(object):
def __init__(self):
pass
+ def onStreamingStarted(self, streamingStarted):
+ """
+ Called when the streaming has been started.
+ """
+ pass
+
def onReceiverStarted(self, receiverStarted):
"""
Called when a receiver has been started
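
The new `onStreamingStarted` callback above is picked up by subclassing `StreamingListener`, in the same way the streaming tests later in this diff do. A minimal sketch; `ssc` is assumed to be an existing StreamingContext:

    from pyspark.streaming.listener import StreamingListener

    class StartTimeListener(StreamingListener):
        def __init__(self):
            super(StartTimeListener, self).__init__()
            self.start_times = []

        def onStreamingStarted(self, streamingStarted):
            # streamingStarted.time is the start time in milliseconds.
            self.start_times.append(streamingStarted.time)

    # ssc.addStreamingListener(StartTimeListener())
    # ssc.start()
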
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 5b86c1cb2c390..373784f826677 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -63,7 +63,7 @@ def setUpClass(cls):
class_name = cls.__name__
conf = SparkConf().set("spark.default.parallelism", 1)
cls.sc = SparkContext(appName=class_name, conf=conf)
- cls.sc.setCheckpointDir("/tmp")
+ cls.sc.setCheckpointDir(tempfile.mkdtemp())
@classmethod
def tearDownClass(cls):
@@ -507,6 +507,10 @@ def __init__(self):
self.batchInfosCompleted = []
self.batchInfosStarted = []
self.batchInfosSubmitted = []
+ self.streamingStartedTime = []
+
+ def onStreamingStarted(self, streamingStarted):
+ self.streamingStartedTime.append(streamingStarted.time)
def onBatchSubmitted(self, batchSubmitted):
self.batchInfosSubmitted.append(batchSubmitted.batchInfo())
@@ -530,9 +534,12 @@ def func(dstream):
batchInfosSubmitted = batch_collector.batchInfosSubmitted
batchInfosStarted = batch_collector.batchInfosStarted
batchInfosCompleted = batch_collector.batchInfosCompleted
+ streamingStartedTime = batch_collector.streamingStartedTime
self.wait_for(batchInfosCompleted, 4)
+ self.assertEqual(len(streamingStartedTime), 1)
+
self.assertGreaterEqual(len(batchInfosSubmitted), 4)
for info in batchInfosSubmitted:
self.assertGreaterEqual(info.batchTime().milliseconds(), 0)
@@ -772,6 +779,12 @@ def func(rdds):
self.assertEqual([2, 3, 1], self._take(dstream, 3))
+ def test_transform_pairrdd(self):
+ # This regression test case is for SPARK-17756.
+ dstream = self.ssc.queueStream(
+ [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd))
+ self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3))
+
def test_get_active(self):
self.assertEqual(StreamingContext.getActive(), None)
@@ -1477,8 +1490,8 @@ def search_kafka_assembly_jar():
raise Exception(
("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) +
"You need to build Spark with "
- "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or "
- "'build/mvn -Pkafka-0-8 package' before running this test.")
+ "'build/sbt -Pkafka-0-8 assembly/package streaming-kafka-0-8-assembly/assembly' or "
+ "'build/mvn -DskipTests -Pkafka-0-8 package' before running this test.")
elif len(jars) > 1:
raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please "
"remove all but one") % (", ".join(jars)))
@@ -1494,8 +1507,8 @@ def search_flume_assembly_jar():
raise Exception(
("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) +
"You need to build Spark with "
- "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or "
- "'build/mvn -Pflume package' before running this test.")
+ "'build/sbt -Pflume assembly/package streaming-flume-assembly/assembly' or "
+ "'build/mvn -DskipTests -Pflume package' before running this test.")
elif len(jars) > 1:
raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please "
"remove all but one") % (", ".join(jars)))
@@ -1503,10 +1516,13 @@ def search_flume_assembly_jar():
return jars[0]
-def search_kinesis_asl_assembly_jar():
+def _kinesis_asl_assembly_dir():
SPARK_HOME = os.environ["SPARK_HOME"]
- kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "external/kinesis-asl-assembly")
- jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly")
+ return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly")
+
+
+def search_kinesis_asl_assembly_jar():
+ jars = search_jar(_kinesis_asl_assembly_dir(), "spark-streaming-kinesis-asl-assembly")
if not jars:
return None
elif len(jars) > 1:
@@ -1539,7 +1555,9 @@ def search_kinesis_asl_assembly_jar():
kinesis_jar_present = True
jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar)
- os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars
+ existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+ jars_args = "--jars %s" % jars
+ os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests,
StreamingListenerTests]
@@ -1569,7 +1587,7 @@ def search_kinesis_asl_assembly_jar():
else:
raise Exception(
("Failed to find Spark Streaming Kinesis assembly jar in %s. "
- % kinesis_asl_assembly_dir) +
+ % _kinesis_asl_assembly_dir()) +
"You need to build Spark with 'build/sbt -Pkinesis-asl "
"assembly/package streaming-kinesis-asl-assembly/assembly'"
"or 'build/mvn -Pkinesis-asl package' before running this test.")
@@ -1580,11 +1598,11 @@ def search_kinesis_asl_assembly_jar():
sys.stderr.write("[Running %s]\n" % (testcase))
tests = unittest.TestLoader().loadTestsFromTestCase(testcase)
if xmlrunner:
- result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=3).run(tests)
+ result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2).run(tests)
if not result.wasSuccessful():
failed = True
else:
- result = unittest.TextTestRunner(verbosity=3).run(tests)
+ result = unittest.TextTestRunner(verbosity=2).run(tests)
if not result.wasSuccessful():
failed = True
sys.exit(failed)
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index abbbf6eb9394f..b4b9f97feb7ca 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -18,6 +18,9 @@
import time
from datetime import datetime
import traceback
+import sys
+
+from py4j.java_gateway import is_instance_of
from pyspark import SparkContext, RDD
@@ -64,7 +67,14 @@ def call(self, milliseconds, jrdds):
t = datetime.fromtimestamp(milliseconds / 1000.0)
r = self.func(t, *rdds)
if r:
- return r._jrdd
+ # Here, we work around the issue by wrapping the result with `map` to ensure `_jrdd`
+ # is a `JavaRDD`. org.apache.spark.streaming.api.python.PythonTransformFunction requires
+ # a `JavaRDD` to be returned; however, some APIs, for example `zip`, produce a `JavaPairRDD`.
+ # See SPARK-17756.
+ if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"):
+ return r._jrdd
+ else:
+ return r.map(lambda x: x)._jrdd
except:
self.failure = traceback.format_exc()
@@ -147,4 +157,4 @@ def rddToFileName(prefix, suffix, timestamp):
import doctest
(failure_count, test_count) = doctest.testmod()
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index e5218d9e75e78..63ae1f30e17ca 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -34,6 +34,7 @@ class TaskContext(object):
_partitionId = None
_stageId = None
_taskAttemptId = None
+ _localProperties = None
def __new__(cls):
"""Even if users construct TaskContext instead of using get, give them the singleton."""
@@ -88,3 +89,9 @@ def taskAttemptId(self):
TaskAttemptID.
"""
return self._taskAttemptId
+
+ def getLocalProperty(self, key):
+ """
+ Get a local property set upstream in the driver, or None if it is missing.
+ """
+ return self._localProperties.get(key, None)
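
`TaskContext.getLocalProperty` above reads local properties set on the driver via `SparkContext.setLocalProperty`, as the new test below exercises. An illustrative sketch; the key and value are made up and `sc` is assumed to be an active SparkContext:

    from pyspark.taskcontext import TaskContext

    sc.setLocalProperty("mykey", "myvalue")
    props = sc.parallelize(range(2), 2) \
        .map(lambda _: TaskContext.get().getLocalProperty("mykey")) \
        .collect()
    print(props)  # ['myvalue', 'myvalue']
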
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 511585763cb01..a4c5fb1db8b37 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -161,6 +161,37 @@ def gen_gs(N, step=1):
self.assertEqual(k, len(vs))
self.assertEqual(list(range(k)), list(vs))
+ def test_stopiteration_is_raised(self):
+
+ def stopit(*args, **kwargs):
+ raise StopIteration()
+
+ def legit_create_combiner(x):
+ return [x]
+
+ def legit_merge_value(x, y):
+ return x.append(y) or x
+
+ def legit_merge_combiners(x, y):
+ return x.extend(y) or x
+
+ data = [(x % 2, x) for x in range(100)]
+
+ # wrong create combiner
+ m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
+ with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+ m.mergeValues(data)
+
+ # wrong merge value
+ m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
+ with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+ m.mergeValues(data)
+
+ # wrong merge combiners
+ m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
+ with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+ m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
+
class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
@@ -543,6 +574,20 @@ def test_tc_on_driver(self):
tc = TaskContext.get()
self.assertTrue(tc is None)
+ def test_get_local_property(self):
+ """Verify that local properties set on the driver are available in TaskContext."""
+ key = "testkey"
+ value = "testvalue"
+ self.sc.setLocalProperty(key, value)
+ try:
+ rdd = self.sc.parallelize(range(1), 1)
+ prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0]
+ self.assertEqual(prop1, value)
+ prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
+ self.assertTrue(prop2 is None)
+ finally:
+ self.sc.setLocalProperty(key, None)
+
class RDDTests(ReusedPySparkTestCase):
@@ -1246,6 +1291,35 @@ def test_pipe_unicode(self):
result = rdd.pipe('cat').collect()
self.assertEqual(data, result)
+ def test_stopiteration_in_user_code(self):
+
+ def stopit(*x):
+ raise StopIteration()
+
+ seq_rdd = self.sc.parallelize(range(10))
+ keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
+ msg = "Caught StopIteration thrown from user's code; failing the task"
+
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+ self.assertRaisesRegexp(Py4JJavaError, msg,
+ seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
+
+ # these methods call the user function both in the driver and in the executor
+ # the exception raised is different according to where the StopIteration happens
+ # RuntimeError is raised if in the driver
+ # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
+ self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+ keyed_rdd.reduceByKeyLocally, stopit)
+ self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+ seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+ self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+ seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+
class ProfilerTests(PySparkTestCase):
@@ -1951,7 +2025,12 @@ class SparkSubmitTests(unittest.TestCase):
def setUp(self):
self.programDir = tempfile.mkdtemp()
- self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit")
+ tmp_dir = tempfile.gettempdir()
+ self.sparkSubmit = [
+ os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"),
+ "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
+ "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
+ ]
def tearDown(self):
shutil.rmtree(self.programDir)
@@ -2017,7 +2096,7 @@ def test_single_script(self):
|sc = SparkContext()
|print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect())
""")
- proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE)
+ proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 4, 6]", out.decode('utf-8'))
@@ -2033,7 +2112,7 @@ def test_script_with_local_functions(self):
|sc = SparkContext()
|print(sc.parallelize([1, 2, 3]).map(foo).collect())
""")
- proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE)
+ proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[3, 6, 9]", out.decode('utf-8'))
@@ -2051,7 +2130,7 @@ def test_module_dependency(self):
|def myfunc(x):
| return x + 1
""")
- proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script],
+ proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script],
stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
@@ -2070,7 +2149,7 @@ def test_module_dependency_on_cluster(self):
|def myfunc(x):
| return x + 1
""")
- proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master",
+ proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master",
"local-cluster[1,1,1024]", script],
stdout=subprocess.PIPE)
out, err = proc.communicate()
@@ -2087,8 +2166,10 @@ def test_package_dependency(self):
|print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
""")
self.create_spark_package("a:mylib:0.1")
- proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
- "file:" + self.programDir, script], stdout=subprocess.PIPE)
+ proc = subprocess.Popen(
+ self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
+ "file:" + self.programDir, script],
+ stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 3, 4]", out.decode('utf-8'))
@@ -2103,9 +2184,11 @@ def test_package_dependency_on_cluster(self):
|print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
""")
self.create_spark_package("a:mylib:0.1")
- proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
- "file:" + self.programDir, "--master",
- "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE)
+ proc = subprocess.Popen(
+ self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
+ "file:" + self.programDir, "--master", "local-cluster[1,1,1024]",
+ script],
+ stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 3, 4]", out.decode('utf-8'))
@@ -2124,7 +2207,7 @@ def test_single_script_on_cluster(self):
# this will fail if you have different spark.executor.memory
# in conf/spark-defaults.conf
proc = subprocess.Popen(
- [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script],
+ self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script],
stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
@@ -2144,7 +2227,7 @@ def test_user_configuration(self):
| sc.stop()
""")
proc = subprocess.Popen(
- [self.sparkSubmit, "--master", "local", script],
+ self.sparkSubmit + ["--master", "local", script],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
out, err = proc.communicate()
@@ -2293,6 +2376,21 @@ def set(self, x=None, other=None, other_x=None):
self.assertEqual(b._x, 2)
+class UtilTests(PySparkTestCase):
+ def test_py4j_exception_message(self):
+ from pyspark.util import _exception_message
+
+ with self.assertRaises(Py4JJavaError) as context:
+ # This attempts java.lang.String(null) which throws an NPE.
+ self.sc._jvm.java.lang.String(None)
+
+ self.assertTrue('NullPointerException' in _exception_message(context.exception))
+
+ def test_parsing_version_string(self):
+ from pyspark.util import VersionUtils
+ self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced"))
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
@@ -2342,15 +2440,7 @@ def test_statcounter_array(self):
if __name__ == "__main__":
from pyspark.tests import *
- if not _have_scipy:
- print("NOTE: Skipping SciPy tests as it does not seem to be installed")
- if not _have_numpy:
- print("NOTE: Skipping NumPy tests as it does not seem to be installed")
if xmlrunner:
- unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
else:
- unittest.main()
- if not _have_scipy:
- print("NOTE: SciPy tests were skipped as it does not seem to be installed")
- if not _have_numpy:
- print("NOTE: NumPy tests were skipped as it does not seem to be installed")
+ unittest.main(verbosity=2)
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index e5d332ce54429..f015542c8799d 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -16,6 +16,11 @@
# limitations under the License.
#
+import re
+import sys
+import inspect
+from py4j.protocol import Py4JJavaError
+
__all__ = []
@@ -33,13 +38,76 @@ def _exception_message(excp):
>>> msg == _exception_message(excp)
True
"""
+ if isinstance(excp, Py4JJavaError):
+ # In Python 2, 'Py4JJavaError' does not expose the Java-side stack trace through its
+ # 'message' attribute. We would normally call 'str' on this exception, but 'Py4JJavaError'
+ # has an issue handling non-ascii strings, so we work around it by calling '__str__()'
+ # directly. Please see SPARK-23517.
+ return excp.__str__()
if hasattr(excp, "message"):
return excp.message
return str(excp)
+def _get_argspec(f):
+ """
+ Get argspec of a function. Supports both Python 2 and Python 3.
+ """
+ if sys.version_info[0] < 3:
+ argspec = inspect.getargspec(f)
+ else:
+ # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
+ # See SPARK-23569.
+ argspec = inspect.getfullargspec(f)
+ return argspec
+
+
+class VersionUtils(object):
+ """
+ Provides utility methods for determining Spark versions from a given version string.
+ """
+ @staticmethod
+ def majorMinorVersion(sparkVersion):
+ """
+ Given a Spark version string, return the (major version number, minor version number).
+ E.g., for 2.0.1-SNAPSHOT, return (2, 0).
+
+ >>> sparkVersion = "2.4.0"
+ >>> VersionUtils.majorMinorVersion(sparkVersion)
+ (2, 4)
+ >>> sparkVersion = "2.3.0-SNAPSHOT"
+ >>> VersionUtils.majorMinorVersion(sparkVersion)
+ (2, 3)
+
+ """
+ m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion)
+ if m is not None:
+ return (int(m.group(1)), int(m.group(2)))
+ else:
+ raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion +
+ " version string, but it could not find the major and minor" +
+ " version numbers.")
+
+
+def fail_on_stopiteration(f):
+ """
+ Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'.
+ This prevents silent loss of data when 'f' is used in a for loop in Spark code.
+ """
+ def wrapper(*args, **kwargs):
+ try:
+ return f(*args, **kwargs)
+ except StopIteration as exc:
+ raise RuntimeError(
+ "Caught StopIteration thrown from user's code; failing the task",
+ exc
+ )
+
+ return wrapper
+
+
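A quick sketch of how this wrapper is intended to be used; `bad_udf` is a hypothetical user function, and the only assumption is the behavior defined directly above:

    from pyspark.util import fail_on_stopiteration

    def bad_udf(x):
        raise StopIteration()  # user bug

    safe_udf = fail_on_stopiteration(bad_udf)
    try:
        safe_udf(1)
    except RuntimeError as e:
        # "Caught StopIteration thrown from user's code; failing the task"
        print(e)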
if __name__ == "__main__":
import doctest
(failure_count, test_count) = doctest.testmod()
if failure_count:
- exit(-1)
+ sys.exit(-1)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 121b3dd1aeec9..38fe2ef06eac5 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,6 +27,7 @@
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.java_gateway import do_server_auth
from pyspark.taskcontext import TaskContext
from pyspark.files import SparkFiles
from pyspark.rdd import PythonEvalType
@@ -34,6 +35,7 @@
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
+from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -80,7 +82,7 @@ def wrap_scalar_pandas_udf(f, return_type):
def verify_result_length(*a):
result = f(*a)
if not hasattr(result, "__len__"):
- raise TypeError("Return type of the user-defined functon should be "
+ raise TypeError("Return type of the user-defined function should be "
"Pandas.Series, but is {}".format(type(result)))
if len(result) != len(a[0]):
raise RuntimeError("Result vector from pandas_udf was not the required length: "
@@ -90,11 +92,16 @@ def verify_result_length(*a):
return lambda *a: (verify_result_length(*a), arrow_return_type)
-def wrap_grouped_map_pandas_udf(f, return_type):
- def wrapped(*series):
+def wrap_grouped_map_pandas_udf(f, return_type, argspec):
+ def wrapped(key_series, value_series):
import pandas as pd
- result = f(pd.concat(series, axis=1))
+ if len(argspec.args) == 1:
+ result = f(pd.concat(value_series, axis=1))
+ elif len(argspec.args) == 2:
+ key = tuple(s[0] for s in key_series)
+ result = f(key, pd.concat(value_series, axis=1))
+
if not isinstance(result, pd.DataFrame):
raise TypeError("Return type of the user-defined function should be "
"pandas.DataFrame, but is {}".format(type(result)))
@@ -116,7 +123,22 @@ def wrap_grouped_agg_pandas_udf(f, return_type):
def wrapped(*series):
import pandas as pd
result = f(*series)
- return pd.Series(result)
+ return pd.Series([result])
+
+ return lambda *a: (wrapped(*a), arrow_return_type)
+
+
+def wrap_window_agg_pandas_udf(f, return_type):
+ # This is similar to grouped_agg_pandas_udf, the only difference
+ # is that window_agg_pandas_udf needs to repeat the return value
+ # to match window length, where grouped_agg_pandas_udf just returns
+ # the scalar value.
+ arrow_return_type = to_arrow_type(return_type)
+
+ def wrapped(*series):
+ import pandas as pd
+ result = f(*series)
+ return pd.Series([result]).repeat(len(series[0]))
return lambda *a: (wrapped(*a), arrow_return_type)
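In other words, the grouped aggregate wrapper returns a single scalar per group, while this window variant repeats that scalar so the output series matches the window length. A small pandas-only sketch of just that repeat step (values are illustrative):

    import pandas as pd

    frame = pd.Series([1.0, 2.0, 3.0, 4.0])            # values inside one window frame
    scalar = frame.mean()                               # what the user's UDF returns
    repeated = pd.Series([scalar]).repeat(len(frame))   # what the wrapper emits
    print(repeated.tolist())                            # [2.5, 2.5, 2.5, 2.5]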
@@ -132,15 +154,22 @@ def read_single_udf(pickleSer, infile, eval_type):
else:
row_func = chain(row_func, f)
+ # make sure StopIteration exceptions raised in the user code are not ignored;
+ # when they are processed in a for loop, raise them as RuntimeError instead
+ func = fail_on_stopiteration(row_func)
+
# the last returnType will be the return type of UDF
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
- return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
+ return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
- return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
+ argspec = _get_argspec(row_func) # signature was lost when wrapping it
+ return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
- return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type)
+ return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
+ elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
+ return arg_offsets, wrap_window_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
- return arg_offsets, wrap_udf(row_func, return_type)
+ return arg_offsets, wrap_udf(func, return_type)
else:
raise ValueError("Unknown eval type: {}".format(eval_type))
@@ -149,23 +178,42 @@ def read_udfs(pickleSer, infile, eval_type):
num_udfs = read_int(infile)
udfs = {}
call_udf = []
- for i in range(num_udfs):
+ mapper_str = ""
+ if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+ # Create function like this:
+ # lambda a: f([a[0]], [a[0], a[1]])
+
+ # We assume there is only one UDF here because grouped map doesn't
+ # support combining multiple UDFs.
+ assert num_udfs == 1
+
+ # See FlatMapGroupsInPandasExec for how arg_offsets are used to
+ # distinguish between grouping attributes and data attributes
arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
- udfs['f%d' % i] = udf
- args = ["a[%d]" % o for o in arg_offsets]
- call_udf.append("f%d(%s)" % (i, ", ".join(args)))
- # Create function like this:
- # lambda a: (f0(a0), f1(a1, a2), f2(a3))
- # In the special case of a single UDF this will return a single result rather
- # than a tuple of results; this is the format that the JVM side expects.
- mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
- mapper = eval(mapper_str, udfs)
+ udfs['f'] = udf
+ split_offset = arg_offsets[0] + 1
+ arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
+ arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
+ mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
+ else:
+ # Create function like this:
+ # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
+ # In the special case of a single UDF this will return a single result rather
+ # than a tuple of results; this is the format that the JVM side expects.
+ for i in range(num_udfs):
+ arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
+ udfs['f%d' % i] = udf
+ args = ["a[%d]" % o for o in arg_offsets]
+ call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+ mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+ mapper = eval(mapper_str, udfs)
func = lambda _, it: map(mapper, it)
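To make the generated mapper concrete: with hypothetical arg_offsets of [1, 0, 0, 1] (one grouping attribute followed by two data attributes), the code above produces the same shape of string shown in the comment and eval's it:

    # split_offset = arg_offsets[0] + 1 = 2
    # grouping attributes -> a[0]; data attributes -> a[0], a[1]
    mapper_str = "lambda a: f([a[0]], [a[0], a[1]])"

    # On the ordinary (non grouped-map) path with two UDFs it would instead be, e.g.:
    #   "lambda a: (f0(a[0]), f1(a[1], a[2]))"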
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
- PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):
timezone = utf8_deserializer.loads(infile)
ser = ArrowStreamPandasSerializer(timezone)
else:
@@ -180,7 +228,7 @@ def main(infile, outfile):
boot_time = time.time()
split_index = read_int(infile)
if split_index == -1: # for unit tests
- exit(-1)
+ sys.exit(-1)
version = utf8_deserializer.loads(infile)
if version != "%d.%d" % sys.version_info[:2]:
@@ -196,6 +244,12 @@ def main(infile, outfile):
taskContext._partitionId = read_int(infile)
taskContext._attemptNumber = read_int(infile)
taskContext._taskAttemptId = read_long(infile)
+ taskContext._localProperties = dict()
+ for i in range(read_int(infile)):
+ k = utf8_deserializer.loads(infile)
+ v = utf8_deserializer.loads(infile)
+ taskContext._localProperties[k] = v
+
shuffle.MemoryBytesSpilled = 0
shuffle.DiskBytesSpilled = 0
_accumulatorRegistry.clear()
@@ -254,7 +308,7 @@ def process():
# Write the error to stderr if it happened while serializing
print("PySpark worker failed with exception:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
write_long(shuffle.MemoryBytesSpilled, outfile)
@@ -272,13 +326,15 @@ def process():
else:
# write a different value to tell JVM to not reuse this worker
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
- exit(-1)
+ sys.exit(-1)
if __name__ == '__main__':
- # Read a local port to connect to from stdin
- java_port = int(sys.stdin.readline())
+ # Read information about how to connect back to the JVM from the environment.
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+ auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(("127.0.0.1", java_port))
sock_file = sock.makefile("rwb", 65536)
+ do_server_auth(sock_file, auth_secret)
main(sock_file, sock_file)
diff --git a/python/run-tests.py b/python/run-tests.py
index 6b41b5ee22814..4c90926cfa350 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -22,16 +22,19 @@
from optparse import OptionParser
import os
import re
+import shutil
import subprocess
import sys
import tempfile
from threading import Thread, Lock
import time
+import uuid
if sys.version < '3':
import Queue
else:
import queue as Queue
from distutils.version import LooseVersion
+from multiprocessing import Manager
# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
@@ -50,6 +53,7 @@ def print_red(text):
print('\033[31m' + text + '\033[0m')
+SKIPPED_TESTS = Manager().dict()
LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
FAILURE_REPORTING_LOCK = Lock()
LOGGER = logging.getLogger()
@@ -66,7 +70,7 @@ def print_red(text):
raise Exception("Cannot find assembly build directory, please build Spark first.")
-def run_individual_python_test(test_name, pyspark_python):
+def run_individual_python_test(target_dir, test_name, pyspark_python):
env = dict(os.environ)
env.update({
'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH,
@@ -75,6 +79,23 @@ def run_individual_python_test(test_name, pyspark_python):
'PYSPARK_PYTHON': which(pyspark_python),
'PYSPARK_DRIVER_PYTHON': which(pyspark_python)
})
+
+ # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is
+ # recognized by the tempfile module to override the default system temp directory.
+ tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
+ while os.path.isdir(tmp_dir):
+ tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
+ os.mkdir(tmp_dir)
+ env["TMPDIR"] = tmp_dir
+
+ # Also override the JVM's temp directory by setting driver and executor options.
+ spark_args = [
+ "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
+ "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
+ "pyspark-shell"
+ ]
+ env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)
+
LOGGER.info("Starting test(%s): %s", pyspark_python, test_name)
start_time = time.time()
try:
@@ -82,6 +103,7 @@ def run_individual_python_test(test_name, pyspark_python):
retcode = subprocess.Popen(
[os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
stderr=per_test_output, stdout=per_test_output, env=env).wait()
+ shutil.rmtree(tmp_dir, ignore_errors=True)
except:
LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
# Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
@@ -109,8 +131,34 @@ def run_individual_python_test(test_name, pyspark_python):
# this code is invoked from a thread other than the main thread.
os._exit(-1)
else:
- per_test_output.close()
- LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
+ skipped_counts = 0
+ try:
+ per_test_output.seek(0)
+ # This expects skipped-test output from unittest when the verbosity level is
+ # 2 (or the --verbose option is enabled).
+ decoded_lines = map(lambda line: line.decode(), iter(per_test_output))
+ skipped_tests = list(filter(
+ lambda line: re.search('test_.* \(pyspark\..*\) ... skipped ', line),
+ decoded_lines))
+ skipped_counts = len(skipped_tests)
+ if skipped_counts > 0:
+ key = (pyspark_python, test_name)
+ SKIPPED_TESTS[key] = skipped_tests
+ per_test_output.close()
+ except:
+ import traceback
+ print_red("\nGot an exception while trying to store "
+ "skipped test output:\n%s" % traceback.format_exc())
+ # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
+ # this code is invoked from a thread other than the main thread.
+ os._exit(-1)
+ if skipped_counts != 0:
+ LOGGER.info(
+ "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name,
+ duration, skipped_counts)
+ else:
+ LOGGER.info(
+ "Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
def get_default_python_executables():
@@ -152,65 +200,17 @@ def parse_opts():
return opts
-def _check_dependencies(python_exec, modules_to_test):
- if "COVERAGE_PROCESS_START" in os.environ:
- # Make sure if coverage is installed.
- try:
- subprocess_check_output(
- [python_exec, "-c", "import coverage"],
- stderr=open(os.devnull, 'w'))
- except:
- print_red("Coverage is not installed in Python executable '%s' "
- "but 'COVERAGE_PROCESS_START' environment variable is set, "
- "exiting." % python_exec)
- sys.exit(-1)
-
- # If we should test 'pyspark-sql', it checks if PyArrow and Pandas are installed and
- # explicitly prints out. See SPARK-23300.
- if pyspark_sql in modules_to_test:
- # TODO(HyukjinKwon): Relocate and deduplicate these version specifications.
- minimum_pyarrow_version = '0.8.0'
- minimum_pandas_version = '0.19.2'
-
- try:
- pyarrow_version = subprocess_check_output(
- [python_exec, "-c", "import pyarrow; print(pyarrow.__version__)"],
- universal_newlines=True,
- stderr=open(os.devnull, 'w')).strip()
- if LooseVersion(pyarrow_version) >= LooseVersion(minimum_pyarrow_version):
- LOGGER.info("Will test PyArrow related features against Python executable "
- "'%s' in '%s' module." % (python_exec, pyspark_sql.name))
- else:
- LOGGER.warning(
- "Will skip PyArrow related features against Python executable "
- "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow "
- "%s was found." % (
- python_exec, pyspark_sql.name, minimum_pyarrow_version, pyarrow_version))
- except:
- LOGGER.warning(
- "Will skip PyArrow related features against Python executable "
- "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow "
- "was not found." % (python_exec, pyspark_sql.name, minimum_pyarrow_version))
-
- try:
- pandas_version = subprocess_check_output(
- [python_exec, "-c", "import pandas; print(pandas.__version__)"],
- universal_newlines=True,
- stderr=open(os.devnull, 'w')).strip()
- if LooseVersion(pandas_version) >= LooseVersion(minimum_pandas_version):
- LOGGER.info("Will test Pandas related features against Python executable "
- "'%s' in '%s' module." % (python_exec, pyspark_sql.name))
- else:
- LOGGER.warning(
- "Will skip Pandas related features against Python executable "
- "'%s' in '%s' module. Pandas >= %s is required; however, Pandas "
- "%s was found." % (
- python_exec, pyspark_sql.name, minimum_pandas_version, pandas_version))
- except:
- LOGGER.warning(
- "Will skip Pandas related features against Python executable "
- "'%s' in '%s' module. Pandas >= %s is required; however, Pandas "
- "was not found." % (python_exec, pyspark_sql.name, minimum_pandas_version))
+def _check_coverage(python_exec):
+ # Make sure coverage is installed.
+ try:
+ subprocess_check_output(
+ [python_exec, "-c", "import coverage"],
+ stderr=open(os.devnull, 'w'))
+ except:
+ print_red("Coverage is not installed in Python executable '%s' "
+ "but 'COVERAGE_PROCESS_START' environment variable is set, "
+ "exiting." % python_exec)
+ sys.exit(-1)
def main():
@@ -237,9 +237,10 @@ def main():
task_queue = Queue.PriorityQueue()
for python_exec in python_execs:
- # Check if the python executable has proper dependencies installed to run tests
- # for given modules properly.
- _check_dependencies(python_exec, modules_to_test)
+ # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START'
+ # environmental variable is set.
+ if "COVERAGE_PROCESS_START" in os.environ:
+ _check_coverage(python_exec)
python_implementation = subprocess_check_output(
[python_exec, "-c", "import platform; print(platform.python_implementation())"],
@@ -257,6 +258,11 @@ def main():
priority = 100
task_queue.put((priority, (python_exec, test_goal)))
+ # Create the target directory before starting tasks to avoid races.
+ target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target'))
+ if not os.path.isdir(target_dir):
+ os.mkdir(target_dir)
+
def process_queue(task_queue):
while True:
try:
@@ -264,7 +270,7 @@ def process_queue(task_queue):
except Queue.Empty:
break
try:
- run_individual_python_test(test_goal, python_exec)
+ run_individual_python_test(target_dir, test_goal, python_exec)
finally:
task_queue.task_done()
@@ -281,6 +287,12 @@ def process_queue(task_queue):
total_duration = time.time() - start_time
LOGGER.info("Tests passed in %i seconds", total_duration)
+ for key, lines in sorted(SKIPPED_TESTS.items()):
+ pyspark_python, test_name = key
+ LOGGER.info("\nSkipped tests in %s with %s:" % (test_name, pyspark_python))
+ for line in lines:
+ LOGGER.info(" %s" % line.rstrip())
+
if __name__ == "__main__":
main()
diff --git a/python/setup.py b/python/setup.py
index 251d4526d4dd0..d309e0564530a 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -26,7 +26,7 @@
if sys.version_info < (2, 7):
print("Python versions prior to 2.7 are not supported for pip installed PySpark.",
file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
try:
exec(open('pyspark/version.py').read())
@@ -98,7 +98,12 @@ def _supports_symlinks():
except:
print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH),
file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
+
+# If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and
+# ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml.
+_minimum_pandas_version = "0.19.2"
+_minimum_pyarrow_version = "0.8.0"
try:
# We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts
@@ -135,7 +140,7 @@ def _supports_symlinks():
if not os.path.isdir(SCRIPTS_TARGET):
print(incorrect_invocation_message, file=sys.stderr)
- exit(-1)
+ sys.exit(-1)
# Scripts directive requires a list of each script path and does not take wild cards.
script_names = os.listdir(SCRIPTS_TARGET)
@@ -196,12 +201,15 @@ def _supports_symlinks():
'pyspark.examples.src.main.python': ['*.py', '*/*.py']},
scripts=scripts,
license='http://www.apache.org/licenses/LICENSE-2.0',
- install_requires=['py4j==0.10.6'],
+ install_requires=['py4j==0.10.7'],
setup_requires=['pypandoc'],
extras_require={
'ml': ['numpy>=1.7'],
'mllib': ['numpy>=1.7'],
- 'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0']
+ 'sql': [
+ 'pandas>=%s' % _minimum_pandas_version,
+ 'pyarrow>=%s' % _minimum_pyarrow_version,
+ ]
},
classifiers=[
'Development Status :: 5 - Production/Stable',
diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json
new file mode 100644
index 0000000000000..9c657fa30ac9c
Binary files /dev/null and b/python/test_support/sql/people_array_utf16le.json differ
diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
index 127f67329f266..4dc399827ffed 100644
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -17,12 +17,10 @@
package org.apache.spark.repl
-import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream, IOException}
-import java.net.{HttpURLConnection, URI, URL, URLEncoder}
+import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream}
+import java.net.{URI, URL, URLEncoder}
import java.nio.channels.Channels
-import scala.util.control.NonFatal
-
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.xbean.asm5._
import org.apache.xbean.asm5.Opcodes._
@@ -30,13 +28,13 @@ import org.apache.xbean.asm5.Opcodes._
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
-import org.apache.spark.util.{ParentClassLoader, Utils}
+import org.apache.spark.util.ParentClassLoader
/**
- * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, used to load classes
- * defined by the interpreter when the REPL is used. Allows the user to specify if user class path
- * should be first. This class loader delegates getting/finding resources to parent loader, which
- * makes sense until REPL never provide resource dynamically.
+ * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load
+ * classes defined by the interpreter when the REPL is used. Allows the user to specify if user
+ * class path should be first. This class loader delegates getting/finding resources to the
+ * parent loader, which makes sense because the REPL never provides resources dynamically.
*
* Note: [[ClassLoader]] will preferentially load class from parent. Only when parent is null or
* the load failed, that it will call the overridden `findClass` function. To avoid the potential
@@ -60,7 +58,6 @@ class ExecutorClassLoader(
private val fetchFn: (String) => InputStream = uri.getScheme() match {
case "spark" => getClassFileInputStreamFromSparkRPC
- case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer
case _ =>
val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf))
getClassFileInputStreamFromFileSystem(fileSystem)
@@ -113,42 +110,6 @@ class ExecutorClassLoader(
}
}
- private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
- val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
- val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
- val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
- newuri.toURL
- } else {
- new URL(classUri + "/" + urlEncode(pathInDirectory))
- }
- val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(),
- SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection]
- // Set the connection timeouts (for testing purposes)
- if (httpUrlConnectionTimeoutMillis != -1) {
- connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
- connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
- }
- connection.connect()
- try {
- if (connection.getResponseCode != 200) {
- // Close the error stream so that the connection is eligible for re-use
- try {
- connection.getErrorStream.close()
- } catch {
- case ioe: IOException =>
- logError("Exception while closing error stream", ioe)
- }
- throw new ClassNotFoundException(s"Class file not found at URL $url")
- } else {
- connection.getInputStream
- }
- } catch {
- case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
- connection.disconnect()
- throw e
- }
- }
-
private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)(
pathInDirectory: String): InputStream = {
val path = new Path(directory, pathInDirectory)
diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala
index cc76a703bdf8f..e4ddcef9772e4 100644
--- a/repl/src/main/scala/org/apache/spark/repl/Main.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala
@@ -44,6 +44,7 @@ object Main extends Logging {
var interp: SparkILoop = _
private var hasErrors = false
+ private var isShellSession = false
private def scalaOptionError(msg: String): Unit = {
hasErrors = true
@@ -53,6 +54,7 @@ object Main extends Logging {
}
def main(args: Array[String]) {
+ isShellSession = true
doMain(args, new SparkILoop)
}
@@ -79,44 +81,50 @@ object Main extends Logging {
}
def createSparkSession(): SparkSession = {
- val execUri = System.getenv("SPARK_EXECUTOR_URI")
- conf.setIfMissing("spark.app.name", "Spark shell")
- // SparkContext will detect this configuration and register it with the RpcEnv's
- // file server, setting spark.repl.class.uri to the actual URI for executors to
- // use. This is sort of ugly but since executors are started as part of SparkContext
- // initialization in certain cases, there's an initialization order issue that prevents
- // this from being set after SparkContext is instantiated.
- conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath())
- if (execUri != null) {
- conf.set("spark.executor.uri", execUri)
- }
- if (System.getenv("SPARK_HOME") != null) {
- conf.setSparkHome(System.getenv("SPARK_HOME"))
- }
+ try {
+ val execUri = System.getenv("SPARK_EXECUTOR_URI")
+ conf.setIfMissing("spark.app.name", "Spark shell")
+ // SparkContext will detect this configuration and register it with the RpcEnv's
+ // file server, setting spark.repl.class.uri to the actual URI for executors to
+ // use. This is sort of ugly but since executors are started as part of SparkContext
+ // initialization in certain cases, there's an initialization order issue that prevents
+ // this from being set after SparkContext is instantiated.
+ conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath())
+ if (execUri != null) {
+ conf.set("spark.executor.uri", execUri)
+ }
+ if (System.getenv("SPARK_HOME") != null) {
+ conf.setSparkHome(System.getenv("SPARK_HOME"))
+ }
- val builder = SparkSession.builder.config(conf)
- if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") {
- if (SparkSession.hiveClassesArePresent) {
- // In the case that the property is not set at all, builder's config
- // does not have this value set to 'hive' yet. The original default
- // behavior is that when there are hive classes, we use hive catalog.
- sparkSession = builder.enableHiveSupport().getOrCreate()
- logInfo("Created Spark session with Hive support")
+ val builder = SparkSession.builder.config(conf)
+ if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") {
+ if (SparkSession.hiveClassesArePresent) {
+ // In the case that the property is not set at all, builder's config
+ // does not have this value set to 'hive' yet. The original default
+ // behavior is that when there are hive classes, we use hive catalog.
+ sparkSession = builder.enableHiveSupport().getOrCreate()
+ logInfo("Created Spark session with Hive support")
+ } else {
+ // Need to change it back to 'in-memory' if no hive classes are found
+ // in the case that the property is set to hive in spark-defaults.conf
+ builder.config(CATALOG_IMPLEMENTATION.key, "in-memory")
+ sparkSession = builder.getOrCreate()
+ logInfo("Created Spark session")
+ }
} else {
- // Need to change it back to 'in-memory' if no hive classes are found
- // in the case that the property is set to hive in spark-defaults.conf
- builder.config(CATALOG_IMPLEMENTATION.key, "in-memory")
+ // In the case that the property is set but not to 'hive', the internal
+ // default is 'in-memory'. So the sparkSession will use in-memory catalog.
sparkSession = builder.getOrCreate()
logInfo("Created Spark session")
}
- } else {
- // In the case that the property is set but not to 'hive', the internal
- // default is 'in-memory'. So the sparkSession will use in-memory catalog.
- sparkSession = builder.getOrCreate()
- logInfo("Created Spark session")
+ sparkContext = sparkSession.sparkContext
+ sparkSession
+ } catch {
+ case e: Exception if isShellSession =>
+ logError("Failed to initialize Spark session.", e)
+ sys.exit(1)
}
- sparkContext = sparkSession.sparkContext
- sparkSession
}
}
diff --git a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala
index ec3d790255ad3..d49e0fd85229f 100644
--- a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala
@@ -350,7 +350,7 @@ class SingletonReplSuite extends SparkFunSuite {
"""
|val timeout = 60000 // 60 seconds
|val start = System.currentTimeMillis
- |while(sc.getExecutorStorageStatus.size != 3 &&
+ |while(sc.statusTracker.getExecutorInfos.size != 3 &&
| (System.currentTimeMillis - start) < timeout) {
| Thread.sleep(10)
|}
@@ -361,11 +361,11 @@ class SingletonReplSuite extends SparkFunSuite {
|case class Foo(i: Int)
|val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_AND_DISK_2)
|ret.count()
- |val res = sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum
+ |val res = sc.getRDDStorageInfo.filter(_.id == ret.id).map(_.numCachedPartitions).sum
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
- assertContains("res: Int = 20", output)
+ assertContains("res: Int = 10", output)
}
test("should clone and clean line object in ClosureCleaner") {
diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml
index a62f271273465..a6dd47a6b7d95 100644
--- a/resource-managers/kubernetes/core/pom.xml
+++ b/resource-managers/kubernetes/core/pom.xml
@@ -77,6 +77,12 @@
+
+ com.squareup.okhttp3
+ okhttp
+ 3.8.1
+
+
org.mockitomockito-core
@@ -84,9 +90,9 @@
- com.squareup.okhttp3
- okhttp
- 3.8.1
+ org.jmock
+ jmock-junit4
+ test
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
index 471196ac0e3f6..bf33179ae3dab 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
@@ -54,6 +54,13 @@ private[spark] object Config extends Logging {
.checkValues(Set("Always", "Never", "IfNotPresent"))
.createWithDefault("IfNotPresent")
+ val IMAGE_PULL_SECRETS =
+ ConfigBuilder("spark.kubernetes.container.image.pullSecrets")
+ .doc("Comma separated list of the Kubernetes secrets used " +
+ "to access private image registries.")
+ .stringConf
+ .createOptional
+
val KUBERNETES_AUTH_DRIVER_CONF_PREFIX =
"spark.kubernetes.authenticate.driver"
val KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX =
@@ -79,12 +86,24 @@ private[spark] object Config extends Logging {
.stringConf
.createOptional
+ val KUBERNETES_DRIVER_SUBMIT_CHECK =
+ ConfigBuilder("spark.kubernetes.submitInDriver")
+ .internal()
+ .booleanConf
+ .createOptional
+
val KUBERNETES_EXECUTOR_LIMIT_CORES =
ConfigBuilder("spark.kubernetes.executor.limit.cores")
.doc("Specify the hard cpu limit for each executor pod")
.stringConf
.createOptional
+ val KUBERNETES_EXECUTOR_REQUEST_CORES =
+ ConfigBuilder("spark.kubernetes.executor.request.cores")
+ .doc("Specify the cpu request for each executor pod")
+ .stringConf
+ .createOptional
+
val KUBERNETES_DRIVER_POD_NAME =
ConfigBuilder("spark.kubernetes.driver.pod.name")
.doc("Name of the driver pod.")
@@ -98,6 +117,28 @@ private[spark] object Config extends Logging {
.stringConf
.createWithDefault("spark")
+ val KUBERNETES_PYSPARK_PY_FILES =
+ ConfigBuilder("spark.kubernetes.python.pyFiles")
+ .doc("The PyFiles that are distributed via client arguments")
+ .internal()
+ .stringConf
+ .createOptional
+
+ val KUBERNETES_PYSPARK_MAIN_APP_RESOURCE =
+ ConfigBuilder("spark.kubernetes.python.mainAppResource")
+ .doc("The main app resource for pyspark jobs")
+ .internal()
+ .stringConf
+ .createOptional
+
+ val KUBERNETES_PYSPARK_APP_ARGS =
+ ConfigBuilder("spark.kubernetes.python.appArgs")
+ .doc("The app arguments for PySpark Jobs")
+ .internal()
+ .stringConf
+ .createOptional
+
+
val KUBERNETES_ALLOCATION_BATCH_SIZE =
ConfigBuilder("spark.kubernetes.allocation.batch.size")
.doc("Number of pods to launch at once in each round of executor allocation.")
@@ -135,72 +176,40 @@ private[spark] object Config extends Logging {
.checkValue(interval => interval > 0, s"Logging interval must be a positive time value.")
.createWithDefaultString("1s")
- val JARS_DOWNLOAD_LOCATION =
- ConfigBuilder("spark.kubernetes.mountDependencies.jarsDownloadDir")
- .doc("Location to download jars to in the driver and executors. When using " +
- "spark-submit, this directory must be empty and will be mounted as an empty directory " +
- "volume on the driver and executor pod.")
- .stringConf
- .createWithDefault("/var/spark-data/spark-jars")
-
- val FILES_DOWNLOAD_LOCATION =
- ConfigBuilder("spark.kubernetes.mountDependencies.filesDownloadDir")
- .doc("Location to download files to in the driver and executors. When using " +
- "spark-submit, this directory must be empty and will be mounted as an empty directory " +
- "volume on the driver and executor pods.")
- .stringConf
- .createWithDefault("/var/spark-data/spark-files")
-
- val INIT_CONTAINER_IMAGE =
- ConfigBuilder("spark.kubernetes.initContainer.image")
- .doc("Image for the driver and executor's init-container for downloading dependencies.")
- .fallbackConf(CONTAINER_IMAGE)
-
- val INIT_CONTAINER_MOUNT_TIMEOUT =
- ConfigBuilder("spark.kubernetes.mountDependencies.timeout")
- .doc("Timeout before aborting the attempt to download and unpack dependencies from remote " +
- "locations into the driver and executor pods.")
- .timeConf(TimeUnit.SECONDS)
- .createWithDefault(300)
-
- val INIT_CONTAINER_MAX_THREAD_POOL_SIZE =
- ConfigBuilder("spark.kubernetes.mountDependencies.maxSimultaneousDownloads")
- .doc("Maximum number of remote dependencies to download simultaneously in a driver or " +
- "executor pod.")
- .intConf
- .createWithDefault(5)
-
- val INIT_CONTAINER_REMOTE_JARS =
- ConfigBuilder("spark.kubernetes.initContainer.remoteJars")
- .doc("Comma-separated list of jar URIs to download in the init-container. This is " +
- "calculated from spark.jars.")
- .internal()
- .stringConf
- .createOptional
-
- val INIT_CONTAINER_REMOTE_FILES =
- ConfigBuilder("spark.kubernetes.initContainer.remoteFiles")
- .doc("Comma-separated list of file URIs to download in the init-container. This is " +
- "calculated from spark.files.")
- .internal()
- .stringConf
- .createOptional
-
- val INIT_CONTAINER_CONFIG_MAP_NAME =
- ConfigBuilder("spark.kubernetes.initContainer.configMapName")
- .doc("Name of the config map to use in the init-container that retrieves submitted files " +
- "for the executor.")
- .internal()
- .stringConf
- .createOptional
+ val KUBERNETES_EXECUTOR_API_POLLING_INTERVAL =
+ ConfigBuilder("spark.kubernetes.executor.apiPollingInterval")
+ .doc("Interval between polls against the Kubernetes API server to inspect the " +
+ "state of executors.")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .checkValue(interval => interval > 0, s"API server polling interval must be a" +
+ " positive time value.")
+ .createWithDefaultString("30s")
+
+ val KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL =
+ ConfigBuilder("spark.kubernetes.executor.eventProcessingInterval")
+ .doc("Interval between successive inspection of executor events sent from the" +
+ " Kubernetes API.")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .checkValue(interval => interval > 0, s"Event processing interval must be a positive" +
+ " time value.")
+ .createWithDefaultString("1s")
- val INIT_CONTAINER_CONFIG_MAP_KEY_CONF =
- ConfigBuilder("spark.kubernetes.initContainer.configMapKey")
- .doc("Key for the entry in the init container config map for submitted files that " +
- "corresponds to the properties for this init-container.")
- .internal()
+ val MEMORY_OVERHEAD_FACTOR =
+ ConfigBuilder("spark.kubernetes.memoryOverheadFactor")
+ .doc("This sets the Memory Overhead Factor that will allocate memory to non-JVM jobs " +
+ "which in the case of JVM tasks will default to 0.10 and 0.40 for non-JVM jobs")
+ .doubleConf
+ .checkValue(mem_overhead => mem_overhead >= 0 && mem_overhead < 1,
+ "Ensure that memory overhead is a double between 0 --> 1.0")
+ .createWithDefault(0.1)
+
+ val PYSPARK_MAJOR_PYTHON_VERSION =
+ ConfigBuilder("spark.kubernetes.pyspark.pythonversion")
+ .doc("This sets the major Python version. Either 2 or 3. (Python2 or Python3)")
.stringConf
- .createOptional
+ .checkValue(pv => List("2", "3").contains(pv),
+ "Ensure that major Python version is either Python2 or Python3")
+ .createWithDefault("2")
val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX =
"spark.kubernetes.authenticate.submission"
@@ -210,10 +219,12 @@ private[spark] object Config extends Logging {
val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label."
val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation."
val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets."
+ val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef."
val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label."
val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation."
val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets."
+ val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef."
- val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv."
+ val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv."
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
index 9411956996843..69bd03d1eda6f 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
@@ -63,26 +63,22 @@ private[spark] object Constants {
val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH"
val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_"
val ENV_CLASSPATH = "SPARK_CLASSPATH"
- val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS"
- val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS"
- val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS"
val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS"
- val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY"
- val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR"
+ val ENV_SPARK_CONF_DIR = "SPARK_CONF_DIR"
+ // Spark app configs for containers
+ val SPARK_CONF_VOLUME = "spark-conf-volume"
+ val SPARK_CONF_DIR_INTERNAL = "/opt/spark/conf"
+ val SPARK_CONF_FILE_NAME = "spark.properties"
+ val SPARK_CONF_PATH = s"$SPARK_CONF_DIR_INTERNAL/$SPARK_CONF_FILE_NAME"
- // Bootstrapping dependencies with the init-container
- val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume"
- val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files-volume"
- val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties"
- val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init"
- val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties"
- val INIT_CONTAINER_PROPERTIES_FILE_PATH =
- s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME"
- val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret"
+ // BINDINGS
+ val ENV_PYSPARK_PRIMARY = "PYSPARK_PRIMARY"
+ val ENV_PYSPARK_FILES = "PYSPARK_FILES"
+ val ENV_PYSPARK_ARGS = "PYSPARK_APP_ARGS"
+ val ENV_PYSPARK_MAJOR_PYTHON_VERSION = "PYSPARK_MAJOR_PYTHON_VERSION"
// Miscellaneous
val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc"
val DRIVER_CONTAINER_NAME = "spark-kubernetes-driver"
- val MEMORY_OVERHEAD_FACTOR = 0.10
val MEMORY_OVERHEAD_MIN_MIB = 384L
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala
deleted file mode 100644
index f6a57dfe00171..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s
-
-import scala.collection.JavaConverters._
-
-import io.fabric8.kubernetes.api.model.{ContainerBuilder, EmptyDirVolumeSource, EnvVarBuilder, PodBuilder, VolumeMount, VolumeMountBuilder}
-
-import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-
-/**
- * Bootstraps an init-container for downloading remote dependencies. This is separated out from
- * the init-container steps API because this component can be used to bootstrap init-containers
- * for both the driver and executors.
- */
-private[spark] class InitContainerBootstrap(
- initContainerImage: String,
- imagePullPolicy: String,
- jarsDownloadPath: String,
- filesDownloadPath: String,
- configMapName: String,
- configMapKey: String,
- sparkRole: String,
- sparkConf: SparkConf) {
-
- /**
- * Bootstraps an init-container that downloads dependencies to be used by a main container.
- */
- def bootstrapInitContainer(
- original: PodWithDetachedInitContainer): PodWithDetachedInitContainer = {
- val sharedVolumeMounts = Seq[VolumeMount](
- new VolumeMountBuilder()
- .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME)
- .withMountPath(jarsDownloadPath)
- .build(),
- new VolumeMountBuilder()
- .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME)
- .withMountPath(filesDownloadPath)
- .build())
-
- val customEnvVarKeyPrefix = sparkRole match {
- case SPARK_POD_DRIVER_ROLE => KUBERNETES_DRIVER_ENV_KEY
- case SPARK_POD_EXECUTOR_ROLE => "spark.executorEnv."
- case _ => throw new SparkException(s"$sparkRole is not a valid Spark pod role")
- }
- val customEnvVars = sparkConf.getAllWithPrefix(customEnvVarKeyPrefix).toSeq.map {
- case (key, value) =>
- new EnvVarBuilder()
- .withName(key)
- .withValue(value)
- .build()
- }
-
- val initContainer = new ContainerBuilder(original.initContainer)
- .withName("spark-init")
- .withImage(initContainerImage)
- .withImagePullPolicy(imagePullPolicy)
- .addAllToEnv(customEnvVars.asJava)
- .addNewVolumeMount()
- .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME)
- .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR)
- .endVolumeMount()
- .addToVolumeMounts(sharedVolumeMounts: _*)
- .addToArgs("init")
- .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH)
- .build()
-
- val podWithBasicVolumes = new PodBuilder(original.pod)
- .editSpec()
- .addNewVolume()
- .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME)
- .withNewConfigMap()
- .withName(configMapName)
- .addNewItem()
- .withKey(configMapKey)
- .withPath(INIT_CONTAINER_PROPERTIES_FILE_NAME)
- .endItem()
- .endConfigMap()
- .endVolume()
- .addNewVolume()
- .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME)
- .withEmptyDir(new EmptyDirVolumeSource())
- .endVolume()
- .addNewVolume()
- .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME)
- .withEmptyDir(new EmptyDirVolumeSource())
- .endVolume()
- .endSpec()
- .build()
-
- val mainContainer = new ContainerBuilder(original.mainContainer)
- .addToVolumeMounts(sharedVolumeMounts: _*)
- .addNewEnv()
- .withName(ENV_MOUNTED_FILES_DIR)
- .withValue(filesDownloadPath)
- .endEnv()
- .build()
-
- PodWithDetachedInitContainer(
- podWithBasicVolumes,
- initContainer,
- mainContainer)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala
new file mode 100644
index 0000000000000..b0ccaa36b01ed
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s
+
+import scala.collection.mutable
+
+import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.submit._
+import org.apache.spark.internal.config.ConfigEntry
+
+
+private[spark] sealed trait KubernetesRoleSpecificConf
+
+/*
+ * Structure containing metadata for Kubernetes logic that builds a Spark driver.
+ */
+private[spark] case class KubernetesDriverSpecificConf(
+ mainAppResource: Option[MainAppResource],
+ mainClass: String,
+ appName: String,
+ appArgs: Seq[String]) extends KubernetesRoleSpecificConf
+
+/*
+ * Structure containing metadata for Kubernetes logic that builds a Spark executor.
+ */
+private[spark] case class KubernetesExecutorSpecificConf(
+ executorId: String,
+ driverPod: Pod)
+ extends KubernetesRoleSpecificConf
+
+/**
+ * Structure containing metadata for Kubernetes logic to build Spark pods.
+ */
+private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf](
+ sparkConf: SparkConf,
+ roleSpecificConf: T,
+ appResourceNamePrefix: String,
+ appId: String,
+ roleLabels: Map[String, String],
+ roleAnnotations: Map[String, String],
+ roleSecretNamesToMountPaths: Map[String, String],
+ roleSecretEnvNamesToKeyRefs: Map[String, String],
+ roleEnvs: Map[String, String],
+ sparkFiles: Seq[String]) {
+
+ def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE)
+
+ def sparkJars(): Seq[String] = sparkConf
+ .getOption("spark.jars")
+ .map(str => str.split(",").toSeq)
+ .getOrElse(Seq.empty[String])
+
+ def pyFiles(): Option[String] = sparkConf
+ .get(KUBERNETES_PYSPARK_PY_FILES)
+
+ def pySparkMainResource(): Option[String] = sparkConf
+ .get(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE)
+
+ def pySparkPythonVersion(): String = sparkConf
+ .get(PYSPARK_MAJOR_PYTHON_VERSION)
+
+ def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY)
+
+ def imagePullSecrets(): Seq[LocalObjectReference] = {
+ sparkConf
+ .get(IMAGE_PULL_SECRETS)
+ .map(_.split(","))
+ .getOrElse(Array.empty[String])
+ .map(_.trim)
+ .map { secret =>
+ new LocalObjectReferenceBuilder().withName(secret).build()
+ }
+ }
+
+ def nodeSelector(): Map[String, String] =
+ KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX)
+
+ def get[T](config: ConfigEntry[T]): T = sparkConf.get(config)
+
+ def get(conf: String): String = sparkConf.get(conf)
+
+ def get(conf: String, defaultValue: String): String = sparkConf.get(conf, defaultValue)
+
+ def getOption(key: String): Option[String] = sparkConf.getOption(key)
+}
+
+private[spark] object KubernetesConf {
+ def createDriverConf(
+ sparkConf: SparkConf,
+ appName: String,
+ appResourceNamePrefix: String,
+ appId: String,
+ mainAppResource: Option[MainAppResource],
+ mainClass: String,
+ appArgs: Array[String],
+ maybePyFiles: Option[String]): KubernetesConf[KubernetesDriverSpecificConf] = {
+ val sparkConfWithMainAppJar = sparkConf.clone()
+ val additionalFiles = mutable.ArrayBuffer.empty[String]
+ mainAppResource.foreach {
+ case JavaMainAppResource(res) =>
+ val previousJars = sparkConf
+ .getOption("spark.jars")
+ .map(_.split(","))
+ .getOrElse(Array.empty)
+ if (!previousJars.contains(res)) {
+ sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res))
+ }
+ // This outer match accounts for multiple non-JVM
+ // bindings, all of which raise MEMORY_OVERHEAD_FACTOR to 0.4
+ case nonJVM: NonJVMResource =>
+ nonJVM match {
+ case PythonMainAppResource(res) =>
+ additionalFiles += res
+ maybePyFiles.foreach { pyFileList =>
+ additionalFiles.appendAll(pyFileList.split(",")) }
+ sparkConfWithMainAppJar.set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, res)
+ }
+ sparkConfWithMainAppJar.setIfMissing(MEMORY_OVERHEAD_FACTOR, 0.4)
+ }
+
+ val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs(
+ sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX)
+ require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " +
+ s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " +
+ "operations.")
+ require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " +
+ s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " +
+ "operations.")
+ val driverLabels = driverCustomLabels ++ Map(
+ SPARK_APP_ID_LABEL -> appId,
+ SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE)
+ val driverAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs(
+ sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX)
+ val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs(
+ sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX)
+ val driverSecretEnvNamesToKeyRefs = KubernetesUtils.parsePrefixedKeyValuePairs(
+ sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX)
+ val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs(
+ sparkConf, KUBERNETES_DRIVER_ENV_PREFIX)
+
+ val sparkFiles = sparkConf
+ .getOption("spark.files")
+ .map(str => str.split(",").toSeq)
+ .getOrElse(Seq.empty[String]) ++ additionalFiles
+
+ KubernetesConf(
+ sparkConfWithMainAppJar,
+ KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs),
+ appResourceNamePrefix,
+ appId,
+ driverLabels,
+ driverAnnotations,
+ driverSecretNamesToMountPaths,
+ driverSecretEnvNamesToKeyRefs,
+ driverEnvs,
+ sparkFiles)
+ }
+
+ def createExecutorConf(
+ sparkConf: SparkConf,
+ executorId: String,
+ appId: String,
+ driverPod: Pod): KubernetesConf[KubernetesExecutorSpecificConf] = {
+ val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs(
+ sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX)
+ require(
+ !executorCustomLabels.contains(SPARK_APP_ID_LABEL),
+ s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.")
+ require(
+ !executorCustomLabels.contains(SPARK_EXECUTOR_ID_LABEL),
+ s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" +
+ " Spark.")
+ require(
+ !executorCustomLabels.contains(SPARK_ROLE_LABEL),
+ s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.")
+ val executorLabels = Map(
+ SPARK_EXECUTOR_ID_LABEL -> executorId,
+ SPARK_APP_ID_LABEL -> appId,
+ SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++
+ executorCustomLabels
+ val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs(
+ sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX)
+ val executorMountSecrets = KubernetesUtils.parsePrefixedKeyValuePairs(
+ sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX)
+ val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs(
+ sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX)
+ val executorEnv = sparkConf.getExecutorEnv.toMap
+
+ KubernetesConf(
+ sparkConf.clone(),
+ KubernetesExecutorSpecificConf(executorId, driverPod),
+ sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX),
+ appId,
+ executorLabels,
+ executorAnnotations,
+ executorMountSecrets,
+ executorEnvSecrets,
+ executorEnv,
+ Seq.empty[String])
+ }
+}
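For orientation, a minimal usage sketch of the driver-side factory above. The image name, resource-name prefix and app id are illustrative, and JavaMainAppResource is assumed to be the case class from the submit package that the pattern match above destructures:

import org.apache.spark.SparkConf
import org.apache.spark.deploy.k8s.KubernetesConf
import org.apache.spark.deploy.k8s.submit.JavaMainAppResource

val conf = new SparkConf().set("spark.kubernetes.container.image", "spark:latest")
val driverConf = KubernetesConf.createDriverConf(
  sparkConf = conf,
  appName = "spark-pi",
  appResourceNamePrefix = "spark-pi-1554226159",   // illustrative prefix
  appId = "spark-application-1554226159",          // illustrative id
  mainAppResource = Some(JavaMainAppResource("local:///opt/spark/examples/jars/spark-examples.jar")),
  mainClass = "org.apache.spark.examples.SparkPi",
  appArgs = Array("1000"),
  maybePyFiles = None)
// driverConf.sparkConf now also lists the example jar under spark.jars.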
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala
new file mode 100644
index 0000000000000..0c5ae022f4070
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s
+
+import io.fabric8.kubernetes.api.model.HasMetadata
+
+private[spark] case class KubernetesDriverSpec(
+ pod: SparkPod,
+ driverKubernetesResources: Seq[HasMetadata],
+ systemProperties: Map[String, String])
+
+private[spark] object KubernetesDriverSpec {
+ def initialSpec(initialProps: Map[String, String]): KubernetesDriverSpec = KubernetesDriverSpec(
+ SparkPod.initialPod(),
+ Seq.empty,
+ initialProps)
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
index 37331d8bbf9b7..593fb531a004d 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
@@ -16,9 +16,7 @@
*/
package org.apache.spark.deploy.k8s
-import java.io.File
-
-import io.fabric8.kubernetes.api.model.{Container, Pod, PodBuilder}
+import io.fabric8.kubernetes.api.model.LocalObjectReference
import org.apache.spark.SparkConf
import org.apache.spark.util.Utils
@@ -43,72 +41,23 @@ private[spark] object KubernetesUtils {
opt1.foreach { _ => require(opt2.isEmpty, errMessage) }
}
- /**
- * Append the given init-container to a pod's list of init-containers.
- *
- * @param originalPodSpec original specification of the pod
- * @param initContainer the init-container to add to the pod
- * @return the pod with the init-container added to the list of InitContainers
- */
- def appendInitContainer(originalPodSpec: Pod, initContainer: Container): Pod = {
- new PodBuilder(originalPodSpec)
- .editOrNewSpec()
- .addToInitContainers(initContainer)
- .endSpec()
- .build()
- }
-
/**
* For the given collection of file URIs, resolves them as follows:
- * - File URIs with scheme file:// are resolved to the given download path.
* - File URIs with scheme local:// resolve to just the path of the URI.
* - Otherwise, the URIs are returned as-is.
*/
- def resolveFileUris(
- fileUris: Iterable[String],
- fileDownloadPath: String): Iterable[String] = {
+ def resolveFileUrisAndPath(fileUris: Iterable[String]): Iterable[String] = {
fileUris.map { uri =>
- resolveFileUri(uri, fileDownloadPath, false)
- }
- }
-
- /**
- * If any file uri has any scheme other than local:// it is mapped as if the file
- * was downloaded to the file download path. Otherwise, it is mapped to the path
- * part of the URI.
- */
- def resolveFilePaths(fileUris: Iterable[String], fileDownloadPath: String): Iterable[String] = {
- fileUris.map { uri =>
- resolveFileUri(uri, fileDownloadPath, true)
- }
- }
-
- /**
- * Get from a given collection of file URIs the ones that represent remote files.
- */
- def getOnlyRemoteFiles(uris: Iterable[String]): Iterable[String] = {
- uris.filter { uri =>
- val scheme = Utils.resolveURI(uri).getScheme
- scheme != "file" && scheme != "local"
+ resolveFileUri(uri)
}
}
- private def resolveFileUri(
- uri: String,
- fileDownloadPath: String,
- assumesDownloaded: Boolean): String = {
+ def resolveFileUri(uri: String): String = {
val fileUri = Utils.resolveURI(uri)
val fileScheme = Option(fileUri.getScheme).getOrElse("file")
fileScheme match {
- case "local" =>
- fileUri.getPath
- case _ =>
- if (assumesDownloaded || fileScheme == "file") {
- val fileName = new File(fileUri.getPath).getName
- s"$fileDownloadPath/$fileName"
- } else {
- uri
- }
+ case "local" => fileUri.getPath
+ case _ => uri
}
}
}
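The simplified resolution now has only two outcomes, illustrated below (the paths are examples; only the scheme matters):

import org.apache.spark.deploy.k8s.KubernetesUtils

// local:// URIs resolve to the bare path inside the container image.
KubernetesUtils.resolveFileUri("local:///opt/spark/examples/jars/spark-examples.jar")
// -> "/opt/spark/examples/jars/spark-examples.jar"

// Every other URI, including file:// and remote schemes, is returned unchanged.
KubernetesUtils.resolveFileUri("hdfs://namenode:8020/jars/app.jar")
// -> "hdfs://namenode:8020/jars/app.jar"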
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala
deleted file mode 100644
index c35e7db51d407..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s
-
-import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder}
-
-/**
- * Bootstraps a driver or executor container or an init-container with needed secrets mounted.
- */
-private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) {
-
- /**
- * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod.
- *
- * @param pod the pod into which the secret volumes are being added.
- * @return the updated pod with the secret volumes added.
- */
- def addSecretVolumes(pod: Pod): Pod = {
- var podBuilder = new PodBuilder(pod)
- secretNamesToMountPaths.keys.foreach { name =>
- podBuilder = podBuilder
- .editOrNewSpec()
- .addNewVolume()
- .withName(secretVolumeName(name))
- .withNewSecret()
- .withSecretName(name)
- .endSecret()
- .endVolume()
- .endSpec()
- }
-
- podBuilder.build()
- }
-
- /**
- * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the
- * given container.
- *
- * @param container the container into which the secret volumes are being mounted.
- * @return the updated container with the secrets mounted.
- */
- def mountSecrets(container: Container): Container = {
- var containerBuilder = new ContainerBuilder(container)
- secretNamesToMountPaths.foreach { case (name, path) =>
- containerBuilder = containerBuilder
- .addNewVolumeMount()
- .withName(secretVolumeName(name))
- .withMountPath(path)
- .endVolumeMount()
- }
-
- containerBuilder.build()
- }
-
- private def secretVolumeName(secretName: String): String = {
- secretName + "-volume"
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala
similarity index 67%
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala
rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala
index 0b79f8b12e806..345dd117fd35f 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala
@@ -16,16 +16,19 @@
*/
package org.apache.spark.deploy.k8s
-import io.fabric8.kubernetes.api.model.{Container, Pod}
+import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder}
-/**
- * Represents a pod with a detached init-container (not yet added to the pod).
- *
- * @param pod the pod
- * @param initContainer the init-container in the pod
- * @param mainContainer the main container in the pod
- */
-private[spark] case class PodWithDetachedInitContainer(
- pod: Pod,
- initContainer: Container,
- mainContainer: Container)
+private[spark] case class SparkPod(pod: Pod, container: Container)
+
+private[spark] object SparkPod {
+ def initialPod(): SparkPod = {
+ SparkPod(
+ new PodBuilder()
+ .withNewMetadata()
+ .endMetadata()
+ .withNewSpec()
+ .endSpec()
+ .build(),
+ new ContainerBuilder().build())
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala
deleted file mode 100644
index c0f08786b76a1..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s
-
-import java.io.File
-import java.util.concurrent.TimeUnit
-
-import scala.concurrent.{ExecutionContext, Future}
-
-import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf}
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.internal.Logging
-import org.apache.spark.util.{ThreadUtils, Utils}
-
-/**
- * Process that fetches files from a resource staging server and/or arbitrary remote locations.
- *
- * The init-container can handle fetching files from any of those sources, but not all of the
- * sources need to be specified. This allows for composing multiple instances of this container
- * with different configurations for different download sources, or using the same container to
- * download everything at once.
- */
-private[spark] class SparkPodInitContainer(
- sparkConf: SparkConf,
- fileFetcher: FileFetcher) extends Logging {
-
- private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE)
- private implicit val downloadExecutor = ExecutionContext.fromExecutorService(
- ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize))
-
- private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION))
- private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION))
-
- private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS)
- private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES)
-
- private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT)
-
- def run(): Unit = {
- logInfo(s"Downloading remote jars: $remoteJars")
- downloadFiles(
- remoteJars,
- jarsDownloadDir,
- s"Remote jars download directory specified at $jarsDownloadDir does not exist " +
- "or is not a directory.")
-
- logInfo(s"Downloading remote files: $remoteFiles")
- downloadFiles(
- remoteFiles,
- filesDownloadDir,
- s"Remote files download directory specified at $filesDownloadDir does not exist " +
- "or is not a directory.")
-
- downloadExecutor.shutdown()
- downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES)
- }
-
- private def downloadFiles(
- filesCommaSeparated: Option[String],
- downloadDir: File,
- errMessage: String): Unit = {
- filesCommaSeparated.foreach { files =>
- require(downloadDir.isDirectory, errMessage)
- Utils.stringToSeq(files).foreach { file =>
- Future[Unit] {
- fileFetcher.fetchFile(file, downloadDir)
- }
- }
- }
- }
-}
-
-private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) {
-
- def fetchFile(uri: String, targetDir: File): Unit = {
- Utils.fetchFile(
- url = uri,
- targetDir = targetDir,
- conf = sparkConf,
- securityMgr = securityManager,
- hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf),
- timestamp = System.currentTimeMillis(),
- useCache = false)
- }
-}
-
-object SparkPodInitContainer extends Logging {
-
- def main(args: Array[String]): Unit = {
- logInfo("Starting init-container to download Spark application dependencies.")
- val sparkConf = new SparkConf(true)
- if (args.nonEmpty) {
- Utils.loadDefaultSparkProperties(sparkConf, args(0))
- }
-
- val securityManager = new SparkSecurityManager(sparkConf)
- val fileFetcher = new FileFetcher(sparkConf, securityManager)
- new SparkPodInitContainer(sparkConf, fileFetcher).run()
- logInfo("Finished downloading application dependencies.")
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala
new file mode 100644
index 0000000000000..143dc8a12304e
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder}
+
+import org.apache.spark.SparkException
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.submit._
+import org.apache.spark.internal.config._
+
+private[spark] class BasicDriverFeatureStep(
+ conf: KubernetesConf[KubernetesDriverSpecificConf])
+ extends KubernetesFeatureConfigStep {
+
+ private val driverPodName = conf
+ .get(KUBERNETES_DRIVER_POD_NAME)
+ .getOrElse(s"${conf.appResourceNamePrefix}-driver")
+
+ private val driverContainerImage = conf
+ .get(DRIVER_CONTAINER_IMAGE)
+ .getOrElse(throw new SparkException("Must specify the driver container image"))
+
+ // CPU settings
+ private val driverCpuCores = conf.get("spark.driver.cores", "1")
+ private val driverLimitCores = conf.get(KUBERNETES_DRIVER_LIMIT_CORES)
+
+ // Memory settings
+ private val driverMemoryMiB = conf.get(DRIVER_MEMORY)
+ private val memoryOverheadMiB = conf
+ .get(DRIVER_MEMORY_OVERHEAD)
+ .getOrElse(math.max((conf.get(MEMORY_OVERHEAD_FACTOR) * driverMemoryMiB).toInt,
+ MEMORY_OVERHEAD_MIN_MIB))
+ private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB
+
+ override def configurePod(pod: SparkPod): SparkPod = {
+ val driverCustomEnvs = conf.roleEnvs
+ .toSeq
+ .map { env =>
+ new EnvVarBuilder()
+ .withName(env._1)
+ .withValue(env._2)
+ .build()
+ }
+
+ val driverCpuQuantity = new QuantityBuilder(false)
+ .withAmount(driverCpuCores)
+ .build()
+ val driverMemoryQuantity = new QuantityBuilder(false)
+ .withAmount(s"${driverMemoryWithOverheadMiB}Mi")
+ .build()
+ val maybeCpuLimitQuantity = driverLimitCores.map { limitCores =>
+ ("cpu", new QuantityBuilder(false).withAmount(limitCores).build())
+ }
+
+ val driverContainer = new ContainerBuilder(pod.container)
+ .withName(DRIVER_CONTAINER_NAME)
+ .withImage(driverContainerImage)
+ .withImagePullPolicy(conf.imagePullPolicy())
+ .addAllToEnv(driverCustomEnvs.asJava)
+ .addNewEnv()
+ .withName(ENV_DRIVER_BIND_ADDRESS)
+ .withValueFrom(new EnvVarSourceBuilder()
+ .withNewFieldRef("v1", "status.podIP")
+ .build())
+ .endEnv()
+ .withNewResources()
+ .addToRequests("cpu", driverCpuQuantity)
+ .addToLimits(maybeCpuLimitQuantity.toMap.asJava)
+ .addToRequests("memory", driverMemoryQuantity)
+ .addToLimits("memory", driverMemoryQuantity)
+ .endResources()
+ .build()
+
+ val driverPod = new PodBuilder(pod.pod)
+ .editOrNewMetadata()
+ .withName(driverPodName)
+ .addToLabels(conf.roleLabels.asJava)
+ .addToAnnotations(conf.roleAnnotations.asJava)
+ .endMetadata()
+ .withNewSpec()
+ .withRestartPolicy("Never")
+ .withNodeSelector(conf.nodeSelector().asJava)
+ .addToImagePullSecrets(conf.imagePullSecrets(): _*)
+ .endSpec()
+ .build()
+ SparkPod(driverPod, driverContainer)
+ }
+
+ override def getAdditionalPodSystemProperties(): Map[String, String] = {
+ val additionalProps = mutable.Map(
+ KUBERNETES_DRIVER_POD_NAME.key -> driverPodName,
+ "spark.app.id" -> conf.appId,
+ KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> conf.appResourceNamePrefix,
+ KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true")
+
+ val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath(
+ conf.sparkJars())
+ val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath(
+ conf.sparkFiles)
+ if (resolvedSparkJars.nonEmpty) {
+ additionalProps.put("spark.jars", resolvedSparkJars.mkString(","))
+ }
+ if (resolvedSparkFiles.nonEmpty) {
+ additionalProps.put("spark.files", resolvedSparkFiles.mkString(","))
+ }
+ additionalProps.toMap
+ }
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty
+}
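A hedged sketch of how this step is exercised, starting from the empty SparkPod defined earlier; driverConf stands for the KubernetesConf produced by createDriverConf:

val step = new BasicDriverFeatureStep(driverConf)
val configured = step.configurePod(SparkPod.initialPod())
// configured.pod carries the driver pod name, labels, annotations and node selector;
// configured.container carries the image, custom env vars and cpu/memory requests and limits.
val extraProps = step.getAdditionalPodSystemProperties()
// extraProps contributes spark.app.id, the driver pod name, the executor pod name prefix
// and the resolved spark.jars / spark.files lists.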
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
new file mode 100644
index 0000000000000..91c54a9776982
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import scala.collection.JavaConverters._
+
+import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder}
+
+import org.apache.spark.SparkException
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD}
+import org.apache.spark.rpc.RpcEndpointAddress
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.util.Utils
+
+private[spark] class BasicExecutorFeatureStep(
+ kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf])
+ extends KubernetesFeatureConfigStep {
+
+ // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf
+ private val executorExtraClasspath = kubernetesConf.get(EXECUTOR_CLASS_PATH)
+ private val executorContainerImage = kubernetesConf
+ .get(EXECUTOR_CONTAINER_IMAGE)
+ .getOrElse(throw new SparkException("Must specify the executor container image"))
+ private val blockManagerPort = kubernetesConf
+ .sparkConf
+ .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT)
+
+ private val executorPodNamePrefix = kubernetesConf.appResourceNamePrefix
+
+ private val driverUrl = RpcEndpointAddress(
+ kubernetesConf.get("spark.driver.host"),
+ kubernetesConf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT),
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
+ private val executorMemoryMiB = kubernetesConf.get(EXECUTOR_MEMORY)
+ private val executorMemoryString = kubernetesConf.get(
+ EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString)
+
+ private val memoryOverheadMiB = kubernetesConf
+ .get(EXECUTOR_MEMORY_OVERHEAD)
+ .getOrElse(math.max(
+ (kubernetesConf.get(MEMORY_OVERHEAD_FACTOR) * executorMemoryMiB).toInt,
+ MEMORY_OVERHEAD_MIN_MIB))
+ private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB
+
+ private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1)
+ private val executorCoresRequest =
+ if (kubernetesConf.sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) {
+ kubernetesConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get
+ } else {
+ executorCores.toString
+ }
+ private val executorLimitCores = kubernetesConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES)
+
+ override def configurePod(pod: SparkPod): SparkPod = {
+ val name = s"$executorPodNamePrefix-exec-${kubernetesConf.roleSpecificConf.executorId}"
+
+    // The hostname must be no longer than 63 characters, so take the last 63 characters of
+    // the pod name as the hostname. This preserves uniqueness since the end of the name
+    // contains the executorId.
+ val hostname = name.substring(Math.max(0, name.length - 63))
+ val executorMemoryQuantity = new QuantityBuilder(false)
+ .withAmount(s"${executorMemoryWithOverhead}Mi")
+ .build()
+ val executorCpuQuantity = new QuantityBuilder(false)
+ .withAmount(executorCoresRequest)
+ .build()
+ val executorExtraClasspathEnv = executorExtraClasspath.map { cp =>
+ new EnvVarBuilder()
+ .withName(ENV_CLASSPATH)
+ .withValue(cp)
+ .build()
+ }
+ val executorExtraJavaOptionsEnv = kubernetesConf
+ .get(EXECUTOR_JAVA_OPTIONS)
+ .map { opts =>
+ val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId,
+ kubernetesConf.roleSpecificConf.executorId)
+ val delimitedOpts = Utils.splitCommandString(subsOpts)
+ delimitedOpts.zipWithIndex.map {
+ case (opt, index) =>
+ new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build()
+ }
+ }.getOrElse(Seq.empty[EnvVar])
+ val executorEnv = (Seq(
+ (ENV_DRIVER_URL, driverUrl),
+ (ENV_EXECUTOR_CORES, executorCores.toString),
+ (ENV_EXECUTOR_MEMORY, executorMemoryString),
+ (ENV_APPLICATION_ID, kubernetesConf.appId),
+ // This is to set the SPARK_CONF_DIR to be /opt/spark/conf
+ (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL),
+ (ENV_EXECUTOR_ID, kubernetesConf.roleSpecificConf.executorId)) ++
+ kubernetesConf.roleEnvs)
+ .map(env => new EnvVarBuilder()
+ .withName(env._1)
+ .withValue(env._2)
+ .build()
+ ) ++ Seq(
+ new EnvVarBuilder()
+ .withName(ENV_EXECUTOR_POD_IP)
+ .withValueFrom(new EnvVarSourceBuilder()
+ .withNewFieldRef("v1", "status.podIP")
+ .build())
+ .build()
+ ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq
+ val requiredPorts = Seq(
+ (BLOCK_MANAGER_PORT_NAME, blockManagerPort))
+ .map { case (name, port) =>
+ new ContainerPortBuilder()
+ .withName(name)
+ .withContainerPort(port)
+ .build()
+ }
+
+ val executorContainer = new ContainerBuilder(pod.container)
+ .withName("executor")
+ .withImage(executorContainerImage)
+ .withImagePullPolicy(kubernetesConf.imagePullPolicy())
+ .withNewResources()
+ .addToRequests("memory", executorMemoryQuantity)
+ .addToLimits("memory", executorMemoryQuantity)
+ .addToRequests("cpu", executorCpuQuantity)
+ .endResources()
+ .addAllToEnv(executorEnv.asJava)
+ .withPorts(requiredPorts.asJava)
+ .addToArgs("executor")
+ .build()
+ val containerWithLimitCores = executorLimitCores.map { limitCores =>
+ val executorCpuLimitQuantity = new QuantityBuilder(false)
+ .withAmount(limitCores)
+ .build()
+ new ContainerBuilder(executorContainer)
+ .editResources()
+ .addToLimits("cpu", executorCpuLimitQuantity)
+ .endResources()
+ .build()
+ }.getOrElse(executorContainer)
+ val driverPod = kubernetesConf.roleSpecificConf.driverPod
+ val executorPod = new PodBuilder(pod.pod)
+ .editOrNewMetadata()
+ .withName(name)
+ .withLabels(kubernetesConf.roleLabels.asJava)
+ .withAnnotations(kubernetesConf.roleAnnotations.asJava)
+ .withOwnerReferences()
+ .addNewOwnerReference()
+ .withController(true)
+ .withApiVersion(driverPod.getApiVersion)
+ .withKind(driverPod.getKind)
+ .withName(driverPod.getMetadata.getName)
+ .withUid(driverPod.getMetadata.getUid)
+ .endOwnerReference()
+ .endMetadata()
+ .editOrNewSpec()
+ .withHostname(hostname)
+ .withRestartPolicy("Never")
+ .withNodeSelector(kubernetesConf.nodeSelector().asJava)
+ .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*)
+ .endSpec()
+ .build()
+ SparkPod(executorPod, containerWithLimitCores)
+ }
+
+ override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty
+}
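A worked example of the memory overhead arithmetic above, assuming the default overhead factor of 0.1 and a 384 MiB minimum:

val executorMemoryMiB = 2048                                            // spark.executor.memory=2g
val memoryOverheadMiB = math.max((0.1 * executorMemoryMiB).toInt, 384)  // = max(204, 384) = 384
val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB  // container requests "2432Mi"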
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala
new file mode 100644
index 0000000000000..ff5ad6673b309
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+
+import scala.collection.JavaConverters._
+
+import com.google.common.io.{BaseEncoding, Files}
+import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret, SecretBuilder}
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+
+private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: KubernetesConf[_])
+ extends KubernetesFeatureConfigStep {
+ // TODO clean up this class, and credentials in general. See also SparkKubernetesClientFactory.
+ // We should use a struct to hold all creds-related fields. A lot of the code is very repetitive.
+
+ private val maybeMountedOAuthTokenFile = kubernetesConf.getOption(
+ s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX")
+ private val maybeMountedClientKeyFile = kubernetesConf.getOption(
+ s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX")
+ private val maybeMountedClientCertFile = kubernetesConf.getOption(
+ s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX")
+ private val maybeMountedCaCertFile = kubernetesConf.getOption(
+ s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX")
+ private val driverServiceAccount = kubernetesConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME)
+
+ private val oauthTokenBase64 = kubernetesConf
+ .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX")
+ .map { token =>
+ BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8))
+ }
+
+ private val caCertDataBase64 = safeFileConfToBase64(
+ s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX",
+ "Driver CA cert file")
+ private val clientKeyDataBase64 = safeFileConfToBase64(
+ s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX",
+ "Driver client key file")
+ private val clientCertDataBase64 = safeFileConfToBase64(
+ s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX",
+ "Driver client cert file")
+
+ // TODO decide whether or not to apply this step entirely in the caller, i.e. the builder.
+ private val shouldMountSecret = oauthTokenBase64.isDefined ||
+ caCertDataBase64.isDefined ||
+ clientKeyDataBase64.isDefined ||
+ clientCertDataBase64.isDefined
+
+ private val driverCredentialsSecretName =
+ s"${kubernetesConf.appResourceNamePrefix}-kubernetes-credentials"
+
+ override def configurePod(pod: SparkPod): SparkPod = {
+ if (!shouldMountSecret) {
+ pod.copy(
+ pod = driverServiceAccount.map { account =>
+ new PodBuilder(pod.pod)
+ .editOrNewSpec()
+ .withServiceAccount(account)
+ .withServiceAccountName(account)
+ .endSpec()
+ .build()
+ }.getOrElse(pod.pod))
+ } else {
+ val driverPodWithMountedKubernetesCredentials =
+ new PodBuilder(pod.pod)
+ .editOrNewSpec()
+ .addNewVolume()
+ .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME)
+ .withNewSecret().withSecretName(driverCredentialsSecretName).endSecret()
+ .endVolume()
+ .endSpec()
+ .build()
+
+ val driverContainerWithMountedSecretVolume =
+ new ContainerBuilder(pod.container)
+ .addNewVolumeMount()
+ .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME)
+ .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR)
+ .endVolumeMount()
+ .build()
+ SparkPod(driverPodWithMountedKubernetesCredentials, driverContainerWithMountedSecretVolume)
+ }
+ }
+
+ override def getAdditionalPodSystemProperties(): Map[String, String] = {
+ val resolvedMountedOAuthTokenFile = resolveSecretLocation(
+ maybeMountedOAuthTokenFile,
+ oauthTokenBase64,
+ DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH)
+ val resolvedMountedClientKeyFile = resolveSecretLocation(
+ maybeMountedClientKeyFile,
+ clientKeyDataBase64,
+ DRIVER_CREDENTIALS_CLIENT_KEY_PATH)
+ val resolvedMountedClientCertFile = resolveSecretLocation(
+ maybeMountedClientCertFile,
+ clientCertDataBase64,
+ DRIVER_CREDENTIALS_CLIENT_CERT_PATH)
+ val resolvedMountedCaCertFile = resolveSecretLocation(
+ maybeMountedCaCertFile,
+ caCertDataBase64,
+ DRIVER_CREDENTIALS_CA_CERT_PATH)
+
+ val redactedTokens = kubernetesConf.sparkConf.getAll
+ .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX))
+ .toMap
+ .mapValues( _ => "")
+ redactedTokens ++
+ resolvedMountedCaCertFile.map { file =>
+ Map(
+ s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" ->
+ file)
+ }.getOrElse(Map.empty) ++
+ resolvedMountedClientKeyFile.map { file =>
+ Map(
+ s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" ->
+ file)
+ }.getOrElse(Map.empty) ++
+ resolvedMountedClientCertFile.map { file =>
+ Map(
+ s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" ->
+ file)
+ }.getOrElse(Map.empty) ++
+ resolvedMountedOAuthTokenFile.map { file =>
+ Map(
+ s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" ->
+ file)
+ }.getOrElse(Map.empty)
+ }
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = {
+ if (shouldMountSecret) {
+ Seq(createCredentialsSecret())
+ } else {
+ Seq.empty
+ }
+ }
+
+ private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = {
+ kubernetesConf.getOption(conf)
+ .map(new File(_))
+ .map { file =>
+ require(file.isFile, String.format("%s provided at %s does not exist or is not a file.",
+ fileType, file.getAbsolutePath))
+ BaseEncoding.base64().encode(Files.toByteArray(file))
+ }
+ }
+
+ /**
+ * Resolve a Kubernetes secret data entry from an optional client credential used by the
+ * driver to talk to the Kubernetes API server.
+ *
+ * @param userSpecifiedCredential the optional user-specified client credential.
+ * @param secretName name of the Kubernetes secret storing the client credential.
+ * @return a secret data entry in the form of a map from the secret name to the secret data,
+ * which may be empty if the user-specified credential is empty.
+ */
+ private def resolveSecretData(
+ userSpecifiedCredential: Option[String],
+ secretName: String): Map[String, String] = {
+ userSpecifiedCredential.map { valueBase64 =>
+ Map(secretName -> valueBase64)
+ }.getOrElse(Map.empty[String, String])
+ }
+
+ private def resolveSecretLocation(
+ mountedUserSpecified: Option[String],
+ valueMountedFromSubmitter: Option[String],
+ mountedCanonicalLocation: String): Option[String] = {
+ mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ =>
+ mountedCanonicalLocation
+ })
+ }
+
+ private def createCredentialsSecret(): Secret = {
+ val allSecretData =
+ resolveSecretData(
+ clientKeyDataBase64,
+ DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++
+ resolveSecretData(
+ clientCertDataBase64,
+ DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++
+ resolveSecretData(
+ caCertDataBase64,
+ DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++
+ resolveSecretData(
+ oauthTokenBase64,
+ DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME)
+
+ new SecretBuilder()
+ .withNewMetadata()
+ .withName(driverCredentialsSecretName)
+ .endMetadata()
+ .withData(allSecretData.asJava)
+ .build()
+ }
+
+}
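A minimal sketch of wiring this step (driverConf as in the earlier sketches):

val credsStep = new DriverKubernetesCredentialsFeatureStep(driverConf)
val withCreds = credsStep.configurePod(SparkPod.initialPod())
val secrets = credsStep.getAdditionalKubernetesResources()
// If any submission-side credential was configured, `secrets` holds one Secret named
// "<appResourceNamePrefix>-kubernetes-credentials" and `withCreds` mounts it under
// DRIVER_CREDENTIALS_SECRETS_BASE_DIR; otherwise only the service account is set on the pod.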
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala
new file mode 100644
index 0000000000000..f2d7bbd08f305
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import scala.collection.JavaConverters._
+
+import io.fabric8.kubernetes.api.model.{HasMetadata, ServiceBuilder}
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.{Clock, SystemClock}
+
+private[spark] class DriverServiceFeatureStep(
+ kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf],
+ clock: Clock = new SystemClock)
+ extends KubernetesFeatureConfigStep with Logging {
+ import DriverServiceFeatureStep._
+
+ require(kubernetesConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty,
+ s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " +
+ "address is managed and set to the driver pod's IP address.")
+ require(kubernetesConf.getOption(DRIVER_HOST_KEY).isEmpty,
+ s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " +
+ "managed via a Kubernetes service.")
+
+ private val preferredServiceName = s"${kubernetesConf.appResourceNamePrefix}$DRIVER_SVC_POSTFIX"
+ private val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) {
+ preferredServiceName
+ } else {
+ val randomServiceId = clock.getTimeMillis()
+ val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX"
+ logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " +
+ s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " +
+ s"$shorterServiceName as the driver service's name.")
+ shorterServiceName
+ }
+
+ private val driverPort = kubernetesConf.sparkConf.getInt(
+ "spark.driver.port", DEFAULT_DRIVER_PORT)
+ private val driverBlockManagerPort = kubernetesConf.sparkConf.getInt(
+ org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT)
+
+ override def configurePod(pod: SparkPod): SparkPod = pod
+
+ override def getAdditionalPodSystemProperties(): Map[String, String] = {
+ val driverHostname = s"$resolvedServiceName.${kubernetesConf.namespace()}.svc"
+ Map(DRIVER_HOST_KEY -> driverHostname,
+ "spark.driver.port" -> driverPort.toString,
+ org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key ->
+ driverBlockManagerPort.toString)
+ }
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = {
+ val driverService = new ServiceBuilder()
+ .withNewMetadata()
+ .withName(resolvedServiceName)
+ .endMetadata()
+ .withNewSpec()
+ .withClusterIP("None")
+ .withSelector(kubernetesConf.roleLabels.asJava)
+ .addNewPort()
+ .withName(DRIVER_PORT_NAME)
+ .withPort(driverPort)
+ .withNewTargetPort(driverPort)
+ .endPort()
+ .addNewPort()
+ .withName(BLOCK_MANAGER_PORT_NAME)
+ .withPort(driverBlockManagerPort)
+ .withNewTargetPort(driverBlockManagerPort)
+ .endPort()
+ .endSpec()
+ .build()
+ Seq(driverService)
+ }
+}
+
+private[spark] object DriverServiceFeatureStep {
+ val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key
+ val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key
+ val DRIVER_SVC_POSTFIX = "-driver-svc"
+ val MAX_SERVICE_NAME_LENGTH = 63
+}
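A short sketch of what this step contributes (driverConf as before; the namespace comes from the conf):

val serviceStep = new DriverServiceFeatureStep(driverConf)
val props = serviceStep.getAdditionalPodSystemProperties()
// props("spark.driver.host") == s"$resolvedServiceName.${driverConf.namespace()}.svc"
val driverService = serviceStep.getAdditionalKubernetesResources().head
// A headless Service (clusterIP "None") that selects the driver pod by its role labels and
// exposes the driver RPC port and the block manager port.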
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala
new file mode 100644
index 0000000000000..03ff7d48420ff
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import scala.collection.JavaConverters._
+
+import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata}
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod}
+
+private[spark] class EnvSecretsFeatureStep(
+ kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf])
+ extends KubernetesFeatureConfigStep {
+ override def configurePod(pod: SparkPod): SparkPod = {
+ val addedEnvSecrets = kubernetesConf
+ .roleSecretEnvNamesToKeyRefs
+      .map { case (envName, keyRef) =>
+        // A keyRef has the form "name:key".
+ val keyRefParts = keyRef.split(":")
+ require(keyRefParts.size == 2, "SecretKeyRef must be in the form name:key.")
+ val name = keyRefParts(0)
+ val key = keyRefParts(1)
+ new EnvVarBuilder()
+ .withName(envName)
+ .withNewValueFrom()
+ .withNewSecretKeyRef()
+ .withKey(key)
+ .withName(name)
+ .endSecretKeyRef()
+ .endValueFrom()
+ .build()
+ }
+
+ val containerWithEnvVars = new ContainerBuilder(pod.container)
+ .addAllToEnv(addedEnvSecrets.toSeq.asJava)
+ .build()
+ SparkPod(pod.pod, containerWithEnvVars)
+ }
+
+ override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty
+}
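The keyRef values consumed above therefore use the form name:key. One illustrative entry (the secret and key names are made up):

val exampleRefs = Map("DB_PASSWORD" -> "db-creds:password")   // secret "db-creds", key "password"
// configurePod adds an env var DB_PASSWORD to the main container whose
// valueFrom.secretKeyRef references secret "db-creds" and key "password".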
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala
new file mode 100644
index 0000000000000..4c1be3bb13293
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import io.fabric8.kubernetes.api.model.HasMetadata
+
+import org.apache.spark.deploy.k8s.SparkPod
+
+/**
+ * A collection of functions that together represent a "feature" in pods that are launched for
+ * Spark drivers and executors.
+ */
+private[spark] trait KubernetesFeatureConfigStep {
+
+ /**
+   * Apply modifications to the given pod in accordance with this feature. This can include
+   * attaching volumes, adding environment variables, and adding labels/annotations.
+ *
+ * Note that we should return a SparkPod that keeps all of the properties of the passed SparkPod
+ * object. So this is correct:
+ *
+ * {@code val configuredPod = new PodBuilder(pod.pod)
+ * .editSpec()
+ * ...
+ * .build()
+ * val configuredContainer = new ContainerBuilder(pod.container)
+ * ...
+ * .build()
+ * SparkPod(configuredPod, configuredContainer)
+ * }
+ *
+ * This is incorrect:
+ *
+ * {@code val configuredPod = new PodBuilder() // Loses the original state
+ * .editSpec()
+ * ...
+ * .build()
+ * val configuredContainer = new ContainerBuilder() // Loses the original state
+ * ...
+ * .build()
+ * SparkPod(configuredPod, configuredContainer)
+ * }
+ *
+ */
+ def configurePod(pod: SparkPod): SparkPod
+
+ /**
+   * Return any system properties that should be set on the JVM in accordance with this
+   * feature.
+ def getAdditionalPodSystemProperties(): Map[String, String]
+
+ /**
+ * Return any additional Kubernetes resources that should be added to support this feature. Only
+ * applicable when creating the driver in cluster mode.
+ */
+ def getAdditionalKubernetesResources(): Seq[HasMetadata]
+}
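To make the contract concrete, here is a minimal hypothetical in-tree step that follows the "correct" pattern from the scaladoc: it edits the incoming pod instead of starting from a fresh builder, so changes made by earlier steps are preserved. The label name and value are illustrative:

import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder}

import org.apache.spark.deploy.k8s.SparkPod
import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep

private[spark] class ExampleLabelFeatureStep extends KubernetesFeatureConfigStep {

  override def configurePod(pod: SparkPod): SparkPod = {
    // Build on the pod handed in; only add one label to its metadata.
    val labeledPod = new PodBuilder(pod.pod)
      .editOrNewMetadata()
        .addToLabels("example-feature", "enabled")
        .endMetadata()
      .build()
    SparkPod(labeledPod, pod.container)
  }

  override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty

  override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty
}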
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala
new file mode 100644
index 0000000000000..70b307303d149
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import java.nio.file.Paths
+import java.util.UUID
+
+import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder}
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod}
+
+private[spark] class LocalDirsFeatureStep(
+ conf: KubernetesConf[_ <: KubernetesRoleSpecificConf],
+ defaultLocalDir: String = s"/var/data/spark-${UUID.randomUUID}")
+ extends KubernetesFeatureConfigStep {
+
+  // Cannot use Utils.getConfiguredLocalDirs because that will default to the Java system
+  // property; instead we default to mounting an emptyDir volume that doesn't already exist
+  // in the image.
+  // We could make Utils.getConfiguredLocalDirs opinionated about Kubernetes, as it is
+  // already a bit opinionated about YARN and Mesos.
+ private val resolvedLocalDirs = Option(conf.sparkConf.getenv("SPARK_LOCAL_DIRS"))
+ .orElse(conf.getOption("spark.local.dir"))
+ .getOrElse(defaultLocalDir)
+ .split(",")
+
+ override def configurePod(pod: SparkPod): SparkPod = {
+ val localDirVolumes = resolvedLocalDirs
+ .zipWithIndex
+ .map { case (localDir, index) =>
+ new VolumeBuilder()
+ .withName(s"spark-local-dir-${index + 1}")
+ .withNewEmptyDir()
+ .endEmptyDir()
+ .build()
+ }
+ val localDirVolumeMounts = localDirVolumes
+ .zip(resolvedLocalDirs)
+ .map { case (localDirVolume, localDirPath) =>
+ new VolumeMountBuilder()
+ .withName(localDirVolume.getName)
+ .withMountPath(localDirPath)
+ .build()
+ }
+ val podWithLocalDirVolumes = new PodBuilder(pod.pod)
+ .editSpec()
+ .addToVolumes(localDirVolumes: _*)
+ .endSpec()
+ .build()
+ val containerWithLocalDirVolumeMounts = new ContainerBuilder(pod.container)
+ .addNewEnv()
+ .withName("SPARK_LOCAL_DIRS")
+ .withValue(resolvedLocalDirs.mkString(","))
+ .endEnv()
+ .addToVolumeMounts(localDirVolumeMounts: _*)
+ .build()
+ SparkPod(podWithLocalDirVolumes, containerWithLocalDirVolumeMounts)
+ }
+
+ override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty
+}
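The resolution order implied by resolvedLocalDirs above, with illustrative values:

// SPARK_LOCAL_DIRS="/tmp/a,/tmp/b"      -> two emptyDir volumes, mounted at /tmp/a and /tmp/b
// else spark.local.dir=/data/scratch    -> one emptyDir volume, mounted at /data/scratch
// else neither is set                   -> one volume at /var/data/spark-<random UUID>
// The comma-joined list of resolved dirs is also exported as SPARK_LOCAL_DIRS on the container.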
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala
new file mode 100644
index 0000000000000..97fa9499b2edb
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder}
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod}
+
+private[spark] class MountSecretsFeatureStep(
+ kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf])
+ extends KubernetesFeatureConfigStep {
+ override def configurePod(pod: SparkPod): SparkPod = {
+ val addedVolumes = kubernetesConf
+ .roleSecretNamesToMountPaths
+ .keys
+ .map(secretName =>
+ new VolumeBuilder()
+ .withName(secretVolumeName(secretName))
+ .withNewSecret()
+ .withSecretName(secretName)
+ .endSecret()
+ .build())
+ val podWithVolumes = new PodBuilder(pod.pod)
+ .editOrNewSpec()
+ .addToVolumes(addedVolumes.toSeq: _*)
+ .endSpec()
+ .build()
+ val addedVolumeMounts = kubernetesConf
+ .roleSecretNamesToMountPaths
+ .map {
+ case (secretName, mountPath) =>
+ new VolumeMountBuilder()
+ .withName(secretVolumeName(secretName))
+ .withMountPath(mountPath)
+ .build()
+ }
+ val containerWithMounts = new ContainerBuilder(pod.container)
+ .addToVolumeMounts(addedVolumeMounts.toSeq: _*)
+ .build()
+ SparkPod(podWithVolumes, containerWithMounts)
+ }
+
+ override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty
+
+ private def secretVolumeName(secretName: String): String = s"$secretName-volume"
+}
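One illustrative entry of roleSecretNamesToMountPaths and its effect (the secret name and path are made up):

val exampleMounts = Map("gcp-creds" -> "/mnt/secrets/gcp")
// configurePod adds a volume named "gcp-creds-volume" backed by the Kubernetes secret
// "gcp-creds" and mounts it at /mnt/secrets/gcp on the main container.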
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala
new file mode 100644
index 0000000000000..f52ec9fdc677e
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features.bindings
+
+import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata}
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.Constants.SPARK_CONF_PATH
+import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep
+import org.apache.spark.launcher.SparkLauncher
+
+private[spark] class JavaDriverFeatureStep(
+ kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf])
+ extends KubernetesFeatureConfigStep {
+ override def configurePod(pod: SparkPod): SparkPod = {
+ val withDriverArgs = new ContainerBuilder(pod.container)
+ .addToArgs("driver")
+ .addToArgs("--properties-file", SPARK_CONF_PATH)
+ .addToArgs("--class", kubernetesConf.roleSpecificConf.mainClass)
+ // The user application jar is merged into the spark.jars list and managed through that
+ // property, so there is no need to reference it explicitly here.
+ .addToArgs(SparkLauncher.NO_RESOURCE)
+ .addToArgs(kubernetesConf.roleSpecificConf.appArgs: _*)
+ .build()
+ SparkPod(pod.pod, withDriverArgs)
+ }
+ override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala
new file mode 100644
index 0000000000000..c20bcac1f8987
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features.bindings
+
+import scala.collection.JavaConverters._
+
+import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata}
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod}
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep
+
+private[spark] class PythonDriverFeatureStep(
+ kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf])
+ extends KubernetesFeatureConfigStep {
+ override def configurePod(pod: SparkPod): SparkPod = {
+ val roleConf = kubernetesConf.roleSpecificConf
+ require(roleConf.mainAppResource.isDefined, "PySpark Main Resource must be defined")
+ val maybePythonArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map(
+ pyArgs =>
+ new EnvVarBuilder()
+ .withName(ENV_PYSPARK_ARGS)
+ .withValue(pyArgs.mkString(","))
+ .build())
+ val maybePythonFiles = kubernetesConf.pyFiles().map(
+      // The py-files are joined with ":" so they can be appended to the PYTHONPATH
+      // of the respective PySpark pod
+ pyFiles =>
+ new EnvVarBuilder()
+ .withName(ENV_PYSPARK_FILES)
+ .withValue(KubernetesUtils.resolveFileUrisAndPath(pyFiles.split(","))
+ .mkString(":"))
+ .build())
+ val envSeq =
+ Seq(new EnvVarBuilder()
+ .withName(ENV_PYSPARK_PRIMARY)
+ .withValue(KubernetesUtils.resolveFileUri(kubernetesConf.pySparkMainResource().get))
+ .build(),
+ new EnvVarBuilder()
+ .withName(ENV_PYSPARK_MAJOR_PYTHON_VERSION)
+ .withValue(kubernetesConf.pySparkPythonVersion())
+ .build())
+ val pythonEnvs = envSeq ++
+ maybePythonArgs.toSeq ++
+ maybePythonFiles.toSeq
+
+ val withPythonPrimaryContainer = new ContainerBuilder(pod.container)
+ .addAllToEnv(pythonEnvs.asJava)
+ .addToArgs("driver-py")
+ .addToArgs("--properties-file", SPARK_CONF_PATH)
+ .addToArgs("--class", roleConf.mainClass)
+ .build()
+
+ SparkPod(pod.pod, withPythonPrimaryContainer)
+ }
+ override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty
+}
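
For reference, a minimal sketch of how the PySpark environment variables above are assembled: --py-files entries are joined with ':' (PYTHONPATH style) while app args are joined with ','. The env var names and the pysparkEnv helper are illustrative assumptions; the real names come from Constants and the real resolution goes through KubernetesUtils.

object PySparkEnvDemo {
  // Env var names here (PYSPARK_PRIMARY, PYSPARK_FILES, ...) are illustrative stand-ins
  // for the constants used by the real step.
  def pysparkEnv(
      primary: String,
      pyFiles: Option[String],
      appArgs: Seq[String],
      pythonVersion: String): Map[String, String] = {
    val base = Map(
      "PYSPARK_PRIMARY" -> primary,
      "PYSPARK_MAJOR_PYTHON_VERSION" -> pythonVersion)
    // --py-files entries are ':'-joined so they can be appended to the PYTHONPATH.
    val filesEnv = pyFiles.map(f => "PYSPARK_FILES" -> f.split(",").mkString(":"))
    // App args are ','-joined into a single env var.
    val argsEnv = Option(appArgs).filter(_.nonEmpty).map(a => "PYSPARK_ARGS" -> a.mkString(","))
    base ++ filesEnv.toSeq ++ argsEnv.toSeq
  }

  def main(args: Array[String]): Unit = {
    println(pysparkEnv("local:///app/main.py", Some("a.py,b.py"), Seq("--n", "5"), "3"))
  }
}
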
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala
deleted file mode 100644
index ae70904621184..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit
-
-import java.util.UUID
-
-import com.google.common.primitives.Longs
-
-import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.submit.steps._
-import org.apache.spark.deploy.k8s.submit.steps.initcontainer.InitContainerConfigOrchestrator
-import org.apache.spark.launcher.SparkLauncher
-import org.apache.spark.util.SystemClock
-import org.apache.spark.util.Utils
-
-/**
- * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to
- * configure the Spark driver pod. The returned steps will be applied one by one in the given
- * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication
- * to construct and create the driver pod. It uses the InitContainerConfigOrchestrator to
- * configure the driver init-container if one is needed, i.e., when there are remote dependencies
- * to localize.
- */
-private[spark] class DriverConfigOrchestrator(
- kubernetesAppId: String,
- launchTime: Long,
- mainAppResource: Option[MainAppResource],
- appName: String,
- mainClass: String,
- appArgs: Array[String],
- sparkConf: SparkConf) {
-
- // The resource name prefix is derived from the Spark application name, making it easy to connect
- // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the
- // application the user submitted.
- private val kubernetesResourceNamePrefix = {
- val uuid = UUID.nameUUIDFromBytes(Longs.toByteArray(launchTime)).toString.replaceAll("-", "")
- s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-")
- }
-
- private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY)
- private val initContainerConfigMapName = s"$kubernetesResourceNamePrefix-init-config"
- private val jarsDownloadPath = sparkConf.get(JARS_DOWNLOAD_LOCATION)
- private val filesDownloadPath = sparkConf.get(FILES_DOWNLOAD_LOCATION)
-
- def getAllConfigurationSteps: Seq[DriverConfigurationStep] = {
- val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs(
- sparkConf,
- KUBERNETES_DRIVER_LABEL_PREFIX)
- require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " +
- s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " +
- "operations.")
- require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " +
- s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " +
- "operations.")
-
- val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs(
- sparkConf,
- KUBERNETES_DRIVER_SECRETS_PREFIX)
-
- val allDriverLabels = driverCustomLabels ++ Map(
- SPARK_APP_ID_LABEL -> kubernetesAppId,
- SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE)
-
- val initialSubmissionStep = new BasicDriverConfigurationStep(
- kubernetesAppId,
- kubernetesResourceNamePrefix,
- allDriverLabels,
- imagePullPolicy,
- appName,
- mainClass,
- appArgs,
- sparkConf)
-
- val serviceBootstrapStep = new DriverServiceBootstrapStep(
- kubernetesResourceNamePrefix,
- allDriverLabels,
- sparkConf,
- new SystemClock)
-
- val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep(
- sparkConf, kubernetesResourceNamePrefix)
-
- val additionalMainAppJar = if (mainAppResource.nonEmpty) {
- val mayBeResource = mainAppResource.get match {
- case JavaMainAppResource(resource) if resource != SparkLauncher.NO_RESOURCE =>
- Some(resource)
- case _ => None
- }
- mayBeResource
- } else {
- None
- }
-
- val sparkJars = sparkConf.getOption("spark.jars")
- .map(_.split(","))
- .getOrElse(Array.empty[String]) ++
- additionalMainAppJar.toSeq
- val sparkFiles = sparkConf.getOption("spark.files")
- .map(_.split(","))
- .getOrElse(Array.empty[String])
-
- // TODO(SPARK-23153): remove once submission client local dependencies are supported.
- if (existSubmissionLocalFiles(sparkJars) || existSubmissionLocalFiles(sparkFiles)) {
- throw new SparkException("The Kubernetes mode does not yet support referencing application " +
- "dependencies in the local file system.")
- }
-
- val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) {
- Seq(new DependencyResolutionStep(
- sparkJars,
- sparkFiles,
- jarsDownloadPath,
- filesDownloadPath))
- } else {
- Nil
- }
-
- val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) {
- Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths)))
- } else {
- Nil
- }
-
- val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) {
- val orchestrator = new InitContainerConfigOrchestrator(
- sparkJars,
- sparkFiles,
- jarsDownloadPath,
- filesDownloadPath,
- imagePullPolicy,
- initContainerConfigMapName,
- INIT_CONTAINER_PROPERTIES_FILE_NAME,
- sparkConf)
- val bootstrapStep = new DriverInitContainerBootstrapStep(
- orchestrator.getAllConfigurationSteps,
- initContainerConfigMapName,
- INIT_CONTAINER_PROPERTIES_FILE_NAME)
-
- Seq(bootstrapStep)
- } else {
- Nil
- }
-
- Seq(
- initialSubmissionStep,
- serviceBootstrapStep,
- kubernetesCredentialsStep) ++
- dependencyResolutionStep ++
- mountSecretsStep ++
- initContainerBootstrapStep
- }
-
- private def existSubmissionLocalFiles(files: Seq[String]): Boolean = {
- files.exists { uri =>
- Utils.resolveURI(uri).getScheme == "file"
- }
- }
-
- private def existNonContainerLocalFiles(files: Seq[String]): Boolean = {
- files.exists { uri =>
- Utils.resolveURI(uri).getScheme != "local"
- }
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala
index 5884348cb3e41..eaff47205dbbc 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala
@@ -16,21 +16,20 @@
*/
package org.apache.spark.deploy.k8s.submit
+import java.io.StringWriter
import java.util.{Collections, UUID}
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-import scala.util.control.NonFatal
+import java.util.Properties
import io.fabric8.kubernetes.api.model._
import io.fabric8.kubernetes.client.KubernetesClient
+import scala.collection.mutable
+import scala.util.control.NonFatal
import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkApplication
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkKubernetesClientFactory}
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory
-import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
@@ -40,11 +39,13 @@ import org.apache.spark.util.Utils
* @param mainAppResource the main application resource if any
* @param mainClass the main class of the application to run
* @param driverArgs arguments to the driver
+ * @param maybePyFiles additional Python files via --py-files
*/
private[spark] case class ClientArguments(
- mainAppResource: Option[MainAppResource],
- mainClass: String,
- driverArgs: Array[String])
+ mainAppResource: Option[MainAppResource],
+ mainClass: String,
+ driverArgs: Array[String],
+ maybePyFiles: Option[String])
private[spark] object ClientArguments {
@@ -52,10 +53,15 @@ private[spark] object ClientArguments {
var mainAppResource: Option[MainAppResource] = None
var mainClass: Option[String] = None
val driverArgs = mutable.ArrayBuffer.empty[String]
+    var maybePyFiles: Option[String] = None
args.sliding(2, 2).toList.foreach {
case Array("--primary-java-resource", primaryJavaResource: String) =>
mainAppResource = Some(JavaMainAppResource(primaryJavaResource))
+ case Array("--primary-py-file", primaryPythonResource: String) =>
+ mainAppResource = Some(PythonMainAppResource(primaryPythonResource))
+ case Array("--other-py-files", pyFiles: String) =>
+ maybePyFiles = Some(pyFiles)
case Array("--main-class", clazz: String) =>
mainClass = Some(clazz)
case Array("--arg", arg: String) =>
@@ -70,7 +76,8 @@ private[spark] object ClientArguments {
ClientArguments(
mainAppResource,
mainClass.get,
- driverArgs.toArray)
+ driverArgs.toArray,
+ maybePyFiles)
}
}
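
A standalone sketch of the paired-flag parsing used in ClientArguments: args.sliding(2, 2) walks the argument list two tokens at a time, so every flag must be immediately followed by its value. PairedArgsDemo and its reduced flag set are hypothetical, kept to two flags for brevity.

object PairedArgsDemo {
  def main(args: Array[String]): Unit = {
    val cli = Array("--main-class", "org.example.Main", "--arg", "one", "--arg", "two")
    var mainClass: Option[String] = None
    val driverArgs = scala.collection.mutable.ArrayBuffer.empty[String]
    // Walk the arguments in non-overlapping pairs of (flag, value).
    cli.sliding(2, 2).toList.foreach {
      case Array("--main-class", clazz) => mainClass = Some(clazz)
      case Array("--arg", arg) => driverArgs += arg
      case other => throw new IllegalArgumentException(s"Unexpected: ${other.mkString(" ")}")
    }
    println(s"mainClass=$mainClass, driverArgs=${driverArgs.mkString(",")}")
  }
}
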
@@ -79,8 +86,9 @@ private[spark] object ClientArguments {
* watcher that monitors and logs the application status. Waits for the application to terminate if
* spark.kubernetes.submission.waitAppCompletion is true.
*
- * @param submissionSteps steps that collectively configure the driver
- * @param sparkConf the submission client Spark configuration
+ * @param builder builds the base driver pod from a composition of the implemented
+ *                feature steps
+ * @param kubernetesConf application configuration
* @param kubernetesClient the client to talk to the Kubernetes API server
* @param waitForAppCompletion a flag indicating whether the client should wait for the application
* to complete
@@ -88,55 +96,41 @@ private[spark] object ClientArguments {
* @param watcher a watcher that monitors and logs the application status
*/
private[spark] class Client(
- submissionSteps: Seq[DriverConfigurationStep],
- sparkConf: SparkConf,
+ builder: KubernetesDriverBuilder,
+ kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf],
kubernetesClient: KubernetesClient,
waitForAppCompletion: Boolean,
appName: String,
- watcher: LoggingPodStatusWatcher) extends Logging {
-
- private val driverJavaOptions = sparkConf.get(
- org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS)
+ watcher: LoggingPodStatusWatcher,
+ kubernetesResourceNamePrefix: String) extends Logging {
- /**
- * Run command that initializes a DriverSpec that will be updated after each
- * DriverConfigurationStep in the sequence that is passed in. The final KubernetesDriverSpec
- * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources
- */
def run(): Unit = {
- var currentDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf)
- // submissionSteps contain steps necessary to take, to resolve varying
- // client arguments that are passed in, created by orchestrator
- for (nextStep <- submissionSteps) {
- currentDriverSpec = nextStep.configureDriver(currentDriverSpec)
- }
-
- val resolvedDriverJavaOpts = currentDriverSpec
- .driverSparkConf
- // Remove this as the options are instead extracted and set individually below using
- // environment variables with prefix SPARK_JAVA_OPT_.
- .remove(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS)
- .getAll
- .map {
- case (confKey, confValue) => s"-D$confKey=$confValue"
- } ++ driverJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty)
- val driverJavaOptsEnvs: Seq[EnvVar] = resolvedDriverJavaOpts.zipWithIndex.map {
- case (option, index) =>
- new EnvVarBuilder()
- .withName(s"$ENV_JAVA_OPT_PREFIX$index")
- .withValue(option)
- .build()
- }
-
- val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer)
- .addAllToEnv(driverJavaOptsEnvs.asJava)
+ val resolvedDriverSpec = builder.buildFromFeatures(kubernetesConf)
+ val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map"
+ val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties)
+    // The SPARK_CONF_DIR environment variable is set so that the Spark command builder
+    // can pick up the Java options present in the ConfigMap
+ val resolvedDriverContainer = new ContainerBuilder(resolvedDriverSpec.pod.container)
+ .addNewEnv()
+ .withName(ENV_SPARK_CONF_DIR)
+ .withValue(SPARK_CONF_DIR_INTERNAL)
+ .endEnv()
+ .addNewVolumeMount()
+ .withName(SPARK_CONF_VOLUME)
+ .withMountPath(SPARK_CONF_DIR_INTERNAL)
+ .endVolumeMount()
.build()
- val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod)
+ val resolvedDriverPod = new PodBuilder(resolvedDriverSpec.pod.pod)
.editSpec()
.addToContainers(resolvedDriverContainer)
+ .addNewVolume()
+ .withName(SPARK_CONF_VOLUME)
+ .withNewConfigMap()
+ .withName(configMapName)
+ .endConfigMap()
+ .endVolume()
.endSpec()
.build()
-
Utils.tryWithResource(
kubernetesClient
.pods()
@@ -144,11 +138,10 @@ private[spark] class Client(
.watch(watcher)) { _ =>
val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod)
try {
- if (currentDriverSpec.otherKubernetesResources.nonEmpty) {
- val otherKubernetesResources = currentDriverSpec.otherKubernetesResources
- addDriverOwnerReference(createdDriverPod, otherKubernetesResources)
- kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace()
- }
+ val otherKubernetesResources =
+ resolvedDriverSpec.driverKubernetesResources ++ Seq(configMap)
+ addDriverOwnerReference(createdDriverPod, otherKubernetesResources)
+ kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace()
} catch {
case NonFatal(e) =>
kubernetesClient.pods().delete(createdDriverPod)
@@ -180,6 +173,23 @@ private[spark] class Client(
originalMetadata.setOwnerReferences(Collections.singletonList(driverPodOwnerReference))
}
}
+
+  // Build a ConfigMap that holds the Spark conf properties in a single file for spark-submit
+ private def buildConfigMap(configMapName: String, conf: Map[String, String]): ConfigMap = {
+ val properties = new Properties()
+ conf.foreach { case (k, v) =>
+ properties.setProperty(k, v)
+ }
+ val propertiesWriter = new StringWriter()
+ properties.store(propertiesWriter,
+ s"Java properties built from Kubernetes config map with name: $configMapName")
+ new ConfigMapBuilder()
+ .withNewMetadata()
+ .withName(configMapName)
+ .endMetadata()
+ .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString)
+ .build()
+ }
}
/**
@@ -193,7 +203,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication {
}
private def run(clientArguments: ClientArguments, sparkConf: SparkConf): Unit = {
- val namespace = sparkConf.get(KUBERNETES_NAMESPACE)
+ val appName = sparkConf.getOption("spark.app.name").getOrElse("spark")
// For constructing the app ID, we can't use the Spark application name, as the app ID is going
// to be added as a label to group resources belonging to the same application. Label values are
// considerably restrictive, e.g. must be no longer than 63 characters in length. So we generate
@@ -201,7 +211,21 @@ private[spark] class KubernetesClientApplication extends SparkApplication {
val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}"
val launchTime = System.currentTimeMillis()
val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION)
- val appName = sparkConf.getOption("spark.app.name").getOrElse("spark")
+ val kubernetesResourceNamePrefix = {
+ s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-")
+ }
+ sparkConf.set(KUBERNETES_PYSPARK_PY_FILES, clientArguments.maybePyFiles.getOrElse(""))
+ val kubernetesConf = KubernetesConf.createDriverConf(
+ sparkConf,
+ appName,
+ kubernetesResourceNamePrefix,
+ kubernetesAppId,
+ clientArguments.mainAppResource,
+ clientArguments.mainClass,
+ clientArguments.driverArgs,
+ clientArguments.maybePyFiles)
+ val builder = new KubernetesDriverBuilder
+ val namespace = kubernetesConf.namespace()
// The master URL has been checked for validity already in SparkSubmit.
// We just need to get rid of the "k8s://" prefix here.
val master = sparkConf.get("spark.master").substring("k8s://".length)
@@ -209,15 +233,6 @@ private[spark] class KubernetesClientApplication extends SparkApplication {
val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval)
- val orchestrator = new DriverConfigOrchestrator(
- kubernetesAppId,
- launchTime,
- clientArguments.mainAppResource,
- appName,
- clientArguments.mainClass,
- clientArguments.driverArgs,
- sparkConf)
-
Utils.tryWithResource(SparkKubernetesClientFactory.createKubernetesClient(
master,
Some(namespace),
@@ -226,12 +241,13 @@ private[spark] class KubernetesClientApplication extends SparkApplication {
None,
None)) { kubernetesClient =>
val client = new Client(
- orchestrator.getAllConfigurationSteps,
- sparkConf,
+ builder,
+ kubernetesConf,
kubernetesClient,
waitForAppCompletion,
appName,
- watcher)
+ watcher,
+ kubernetesResourceNamePrefix)
client.run()
}
}
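
The new buildConfigMap path serializes the resolved system properties into a single Java properties file that is mounted at SPARK_CONF_DIR_INTERNAL and read inside the container. Below is a minimal, self-contained sketch of just the payload construction; the fabric8 ConfigMapBuilder call is omitted, and ConfMapPayloadDemo/toPropertiesFile are names made up for this example.

import java.io.StringWriter
import java.util.Properties

object ConfMapPayloadDemo {
  // Flatten a Spark conf map into the text of a Java properties file.
  def toPropertiesFile(conf: Map[String, String], comment: String): String = {
    val props = new Properties()
    conf.foreach { case (k, v) => props.setProperty(k, v) }
    val writer = new StringWriter()
    props.store(writer, comment)
    writer.toString
  }

  def main(args: Array[String]): Unit = {
    val payload = toPropertiesFile(
      Map("spark.app.name" -> "demo", "spark.kubernetes.namespace" -> "default"),
      "built for config map spark-demo-driver-conf-map")
    println(payload)
  }
}
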
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
new file mode 100644
index 0000000000000..5762d8245f778
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.submit
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf}
+import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep}
+import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep}
+
+private[spark] class KubernetesDriverBuilder(
+ provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep =
+ new BasicDriverFeatureStep(_),
+ provideCredentialsStep: (KubernetesConf[KubernetesDriverSpecificConf])
+ => DriverKubernetesCredentialsFeatureStep =
+ new DriverKubernetesCredentialsFeatureStep(_),
+ provideServiceStep: (KubernetesConf[KubernetesDriverSpecificConf]) => DriverServiceFeatureStep =
+ new DriverServiceFeatureStep(_),
+ provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]
+ => MountSecretsFeatureStep) =
+ new MountSecretsFeatureStep(_),
+ provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]
+ => EnvSecretsFeatureStep) =
+ new EnvSecretsFeatureStep(_),
+ provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]
+ => LocalDirsFeatureStep) =
+ new LocalDirsFeatureStep(_),
+ provideJavaStep: (
+ KubernetesConf[KubernetesDriverSpecificConf]
+ => JavaDriverFeatureStep) =
+ new JavaDriverFeatureStep(_),
+ providePythonStep: (
+ KubernetesConf[KubernetesDriverSpecificConf]
+ => PythonDriverFeatureStep) =
+ new PythonDriverFeatureStep(_)) {
+
+ def buildFromFeatures(
+ kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = {
+ val baseFeatures = Seq(
+ provideBasicStep(kubernetesConf),
+ provideCredentialsStep(kubernetesConf),
+ provideServiceStep(kubernetesConf),
+ provideLocalDirsStep(kubernetesConf))
+
+    val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) {
+      Some(provideSecretsStep(kubernetesConf))
+    } else None
+
+    val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) {
+      Some(provideEnvSecretsStep(kubernetesConf))
+    } else None
+
+    val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map {
+      case JavaMainAppResource(_) => provideJavaStep(kubernetesConf)
+      case PythonMainAppResource(_) => providePythonStep(kubernetesConf)
+    }.getOrElse(provideJavaStep(kubernetesConf))
+
+ val allFeatures: Seq[KubernetesFeatureConfigStep] =
+ (baseFeatures :+ bindingsStep) ++
+ maybeRoleSecretNamesStep.toSeq ++
+ maybeProvideSecretsStep.toSeq
+
+ var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap)
+ for (feature <- allFeatures) {
+ val configuredPod = feature.configurePod(spec.pod)
+ val addedSystemProperties = feature.getAdditionalPodSystemProperties()
+ val addedResources = feature.getAdditionalKubernetesResources()
+ spec = KubernetesDriverSpec(
+ configuredPod,
+ spec.driverKubernetesResources ++ addedResources,
+ spec.systemProperties ++ addedSystemProperties)
+ }
+ spec
+ }
+}
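
The loop at the end of buildFromFeatures is effectively a fold: each feature step contributes a reconfigured pod plus any extra Kubernetes resources and system properties, accumulated into a single driver spec. The sketch below mirrors that accumulation with simplified stand-ins (MiniSpec, MiniStep) rather than the real Spark types.

// Standalone sketch of the accumulation in buildFromFeatures.
final case class MiniSpec(podLabels: Map[String, String], resources: Seq[String], props: Map[String, String])

trait MiniStep {
  def configure(labels: Map[String, String]): Map[String, String]
  def extraResources: Seq[String] = Nil
  def extraProps: Map[String, String] = Map.empty
}

object DriverBuilderDemo {
  // Thread the spec through every step, merging resources and properties as we go.
  def build(steps: Seq[MiniStep]): MiniSpec =
    steps.foldLeft(MiniSpec(Map.empty, Nil, Map.empty)) { (spec, step) =>
      MiniSpec(
        step.configure(spec.podLabels),
        spec.resources ++ step.extraResources,
        spec.props ++ step.extraProps)
    }

  def main(args: Array[String]): Unit = {
    val labelStep = new MiniStep {
      def configure(l: Map[String, String]) = l + ("spark-role" -> "driver")
    }
    val serviceStep = new MiniStep {
      def configure(l: Map[String, String]) = l
      override def extraResources = Seq("driver-svc")
      override def extraProps = Map("spark.driver.host" -> "driver-svc.default.svc")
    }
    println(build(Seq(labelStep, serviceStep)))
  }
}
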
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala
deleted file mode 100644
index db13f09387ef9..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit
-
-import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, HasMetadata, Pod, PodBuilder}
-
-import org.apache.spark.SparkConf
-
-/**
- * Represents the components and characteristics of a Spark driver. The driver can be considered
- * as being comprised of the driver pod itself, any other Kubernetes resources that the driver
- * pod depends on, and the SparkConf that should be supplied to the Spark application. The driver
- * container should be operated on via the specific field of this case class as opposed to trying
- * to edit the container directly on the pod. The driver container should be attached at the
- * end of executing all submission steps.
- */
-private[spark] case class KubernetesDriverSpec(
- driverPod: Pod,
- driverContainer: Container,
- otherKubernetesResources: Seq[HasMetadata],
- driverSparkConf: SparkConf)
-
-private[spark] object KubernetesDriverSpec {
- def initialSpec(initialSparkConf: SparkConf): KubernetesDriverSpec = {
- KubernetesDriverSpec(
- // Set new metadata and a new spec so that submission steps can use
- // PodBuilder#editMetadata() and/or PodBuilder#editSpec() safely.
- new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(),
- new ContainerBuilder().build(),
- Seq.empty[HasMetadata],
- initialSparkConf.clone())
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala
index cca9f4627a1f6..cbe081ae35683 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala
@@ -18,4 +18,9 @@ package org.apache.spark.deploy.k8s.submit
private[spark] sealed trait MainAppResource
+private[spark] sealed trait NonJVMResource
+
private[spark] case class JavaMainAppResource(primaryResource: String) extends MainAppResource
+
+private[spark] case class PythonMainAppResource(primaryResource: String)
+ extends MainAppResource with NonJVMResource
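
Because MainAppResource is sealed, the builder can select the binding step with an exhaustive match and fall back to the JVM binding when no primary resource is given. A small self-contained sketch of that selection; DemoResource and the returned strings are illustrative only.

sealed trait DemoResource
final case class DemoJavaResource(res: String) extends DemoResource
final case class DemoPythonResource(res: String) extends DemoResource

object BindingSelectionDemo {
  // Pick the driver entry point based on the (optional) primary resource type.
  def selectBinding(resource: Option[DemoResource]): String = resource match {
    case Some(DemoPythonResource(_)) => "driver-py"
    case Some(DemoJavaResource(_)) | None => "driver"
  }

  def main(args: Array[String]): Unit = {
    println(selectBinding(Some(DemoPythonResource("main.py"))))  // driver-py
    println(selectBinding(None))                                 // driver
  }
}
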
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala
deleted file mode 100644
index 164e2e5594778..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala
+++ /dev/null
@@ -1,161 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps
-
-import scala.collection.JavaConverters._
-
-import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, PodBuilder, QuantityBuilder}
-
-import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.KubernetesUtils
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
-import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD}
-
-/**
- * Performs basic configuration for the driver pod.
- */
-private[spark] class BasicDriverConfigurationStep(
- kubernetesAppId: String,
- resourceNamePrefix: String,
- driverLabels: Map[String, String],
- imagePullPolicy: String,
- appName: String,
- mainClass: String,
- appArgs: Array[String],
- sparkConf: SparkConf) extends DriverConfigurationStep {
-
- private val driverPodName = sparkConf
- .get(KUBERNETES_DRIVER_POD_NAME)
- .getOrElse(s"$resourceNamePrefix-driver")
-
- private val driverExtraClasspath = sparkConf.get(DRIVER_CLASS_PATH)
-
- private val driverContainerImage = sparkConf
- .get(DRIVER_CONTAINER_IMAGE)
- .getOrElse(throw new SparkException("Must specify the driver container image"))
-
- // CPU settings
- private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1")
- private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES)
-
- // Memory settings
- private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY)
- private val driverMemoryString = sparkConf.get(
- DRIVER_MEMORY.key, DRIVER_MEMORY.defaultValueString)
- private val memoryOverheadMiB = sparkConf
- .get(DRIVER_MEMORY_OVERHEAD)
- .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB))
- private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB
-
- override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
- val driverExtraClasspathEnv = driverExtraClasspath.map { classPath =>
- new EnvVarBuilder()
- .withName(ENV_CLASSPATH)
- .withValue(classPath)
- .build()
- }
-
- val driverCustomAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs(
- sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX)
- require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION),
- s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" +
- " Spark bookkeeping operations.")
-
- val driverCustomEnvs = sparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq
- .map { env =>
- new EnvVarBuilder()
- .withName(env._1)
- .withValue(env._2)
- .build()
- }
-
- val driverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName)
-
- val nodeSelector = KubernetesUtils.parsePrefixedKeyValuePairs(
- sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX)
-
- val driverCpuQuantity = new QuantityBuilder(false)
- .withAmount(driverCpuCores)
- .build()
- val driverMemoryQuantity = new QuantityBuilder(false)
- .withAmount(s"${driverMemoryMiB}Mi")
- .build()
- val driverMemoryLimitQuantity = new QuantityBuilder(false)
- .withAmount(s"${driverMemoryWithOverheadMiB}Mi")
- .build()
- val maybeCpuLimitQuantity = driverLimitCores.map { limitCores =>
- ("cpu", new QuantityBuilder(false).withAmount(limitCores).build())
- }
-
- val driverContainer = new ContainerBuilder(driverSpec.driverContainer)
- .withName(DRIVER_CONTAINER_NAME)
- .withImage(driverContainerImage)
- .withImagePullPolicy(imagePullPolicy)
- .addAllToEnv(driverCustomEnvs.asJava)
- .addToEnv(driverExtraClasspathEnv.toSeq: _*)
- .addNewEnv()
- .withName(ENV_DRIVER_MEMORY)
- .withValue(driverMemoryString)
- .endEnv()
- .addNewEnv()
- .withName(ENV_DRIVER_MAIN_CLASS)
- .withValue(mainClass)
- .endEnv()
- .addNewEnv()
- .withName(ENV_DRIVER_ARGS)
- .withValue(appArgs.mkString(" "))
- .endEnv()
- .addNewEnv()
- .withName(ENV_DRIVER_BIND_ADDRESS)
- .withValueFrom(new EnvVarSourceBuilder()
- .withNewFieldRef("v1", "status.podIP")
- .build())
- .endEnv()
- .withNewResources()
- .addToRequests("cpu", driverCpuQuantity)
- .addToRequests("memory", driverMemoryQuantity)
- .addToLimits("memory", driverMemoryLimitQuantity)
- .addToLimits(maybeCpuLimitQuantity.toMap.asJava)
- .endResources()
- .addToArgs("driver")
- .build()
-
- val baseDriverPod = new PodBuilder(driverSpec.driverPod)
- .editOrNewMetadata()
- .withName(driverPodName)
- .addToLabels(driverLabels.asJava)
- .addToAnnotations(driverAnnotations.asJava)
- .endMetadata()
- .withNewSpec()
- .withRestartPolicy("Never")
- .withNodeSelector(nodeSelector.asJava)
- .endSpec()
- .build()
-
- val resolvedSparkConf = driverSpec.driverSparkConf.clone()
- .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName)
- .set("spark.app.id", kubernetesAppId)
- .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix)
-
- driverSpec.copy(
- driverPod = baseDriverPod,
- driverSparkConf = resolvedSparkConf,
- driverContainer = driverContainer)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala
deleted file mode 100644
index d4b83235b4e3b..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps
-
-import java.io.File
-
-import io.fabric8.kubernetes.api.model.ContainerBuilder
-
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.KubernetesUtils
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
-
-/**
- * Step that configures the classpath, spark.jars, and spark.files for the driver given that the
- * user may provide remote files or files with local:// schemes.
- */
-private[spark] class DependencyResolutionStep(
- sparkJars: Seq[String],
- sparkFiles: Seq[String],
- jarsDownloadPath: String,
- filesDownloadPath: String) extends DriverConfigurationStep {
-
- override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
- val resolvedSparkJars = KubernetesUtils.resolveFileUris(sparkJars, jarsDownloadPath)
- val resolvedSparkFiles = KubernetesUtils.resolveFileUris(sparkFiles, filesDownloadPath)
-
- val sparkConf = driverSpec.driverSparkConf.clone()
- if (resolvedSparkJars.nonEmpty) {
- sparkConf.set("spark.jars", resolvedSparkJars.mkString(","))
- }
- if (resolvedSparkFiles.nonEmpty) {
- sparkConf.set("spark.files", resolvedSparkFiles.mkString(","))
- }
-
- val resolvedClasspath = KubernetesUtils.resolveFilePaths(sparkJars, jarsDownloadPath)
- val resolvedDriverContainer = if (resolvedClasspath.nonEmpty) {
- new ContainerBuilder(driverSpec.driverContainer)
- .addNewEnv()
- .withName(ENV_MOUNTED_CLASSPATH)
- .withValue(resolvedClasspath.mkString(File.pathSeparator))
- .endEnv()
- .build()
- } else {
- driverSpec.driverContainer
- }
-
- driverSpec.copy(
- driverContainer = resolvedDriverContainer,
- driverSparkConf = sparkConf)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala
deleted file mode 100644
index 9fb3dafdda540..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps
-
-import java.io.StringWriter
-import java.util.Properties
-
-import io.fabric8.kubernetes.api.model.{ConfigMap, ConfigMapBuilder, ContainerBuilder, HasMetadata}
-
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.KubernetesUtils
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
-import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec}
-
-/**
- * Configures the driver init-container that localizes remote dependencies into the driver pod.
- * It applies the given InitContainerConfigurationSteps in the given order to produce a final
- * InitContainerSpec that is then used to configure the driver pod with the init-container attached.
- * It also builds a ConfigMap that will be mounted into the init-container. The ConfigMap carries
- * configuration properties for the init-container.
- */
-private[spark] class DriverInitContainerBootstrapStep(
- steps: Seq[InitContainerConfigurationStep],
- configMapName: String,
- configMapKey: String)
- extends DriverConfigurationStep {
-
- override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
- var initContainerSpec = InitContainerSpec(
- properties = Map.empty[String, String],
- driverSparkConf = Map.empty[String, String],
- initContainer = new ContainerBuilder().build(),
- driverContainer = driverSpec.driverContainer,
- driverPod = driverSpec.driverPod,
- dependentResources = Seq.empty[HasMetadata])
- for (nextStep <- steps) {
- initContainerSpec = nextStep.configureInitContainer(initContainerSpec)
- }
-
- val configMap = buildConfigMap(
- configMapName,
- configMapKey,
- initContainerSpec.properties)
- val resolvedDriverSparkConf = driverSpec.driverSparkConf
- .clone()
- .set(INIT_CONTAINER_CONFIG_MAP_NAME, configMapName)
- .set(INIT_CONTAINER_CONFIG_MAP_KEY_CONF, configMapKey)
- .setAll(initContainerSpec.driverSparkConf)
- val resolvedDriverPod = KubernetesUtils.appendInitContainer(
- initContainerSpec.driverPod, initContainerSpec.initContainer)
-
- driverSpec.copy(
- driverPod = resolvedDriverPod,
- driverContainer = initContainerSpec.driverContainer,
- driverSparkConf = resolvedDriverSparkConf,
- otherKubernetesResources =
- driverSpec.otherKubernetesResources ++
- initContainerSpec.dependentResources ++
- Seq(configMap))
- }
-
- private def buildConfigMap(
- configMapName: String,
- configMapKey: String,
- config: Map[String, String]): ConfigMap = {
- val properties = new Properties()
- config.foreach { entry =>
- properties.setProperty(entry._1, entry._2)
- }
- val propertiesWriter = new StringWriter()
- properties.store(propertiesWriter,
- s"Java properties built from Kubernetes config map with name: $configMapName " +
- s"and config map key: $configMapKey")
- new ConfigMapBuilder()
- .withNewMetadata()
- .withName(configMapName)
- .endMetadata()
- .addToData(configMapKey, propertiesWriter.toString)
- .build()
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala
deleted file mode 100644
index ccc18908658f1..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala
+++ /dev/null
@@ -1,245 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps
-
-import java.io.File
-import java.nio.charset.StandardCharsets
-
-import scala.collection.JavaConverters._
-import scala.language.implicitConversions
-
-import com.google.common.io.{BaseEncoding, Files}
-import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder, Secret, SecretBuilder}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
-
-/**
- * Mounts Kubernetes credentials into the driver pod. The driver will use such mounted credentials
- * to request executors.
- */
-private[spark] class DriverKubernetesCredentialsStep(
- submissionSparkConf: SparkConf,
- kubernetesResourceNamePrefix: String) extends DriverConfigurationStep {
-
- private val maybeMountedOAuthTokenFile = submissionSparkConf.getOption(
- s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX")
- private val maybeMountedClientKeyFile = submissionSparkConf.getOption(
- s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX")
- private val maybeMountedClientCertFile = submissionSparkConf.getOption(
- s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX")
- private val maybeMountedCaCertFile = submissionSparkConf.getOption(
- s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX")
- private val driverServiceAccount = submissionSparkConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME)
-
- override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
- val driverSparkConf = driverSpec.driverSparkConf.clone()
-
- val oauthTokenBase64 = submissionSparkConf
- .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX")
- .map { token =>
- BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8))
- }
- val caCertDataBase64 = safeFileConfToBase64(
- s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX",
- "Driver CA cert file")
- val clientKeyDataBase64 = safeFileConfToBase64(
- s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX",
- "Driver client key file")
- val clientCertDataBase64 = safeFileConfToBase64(
- s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX",
- "Driver client cert file")
-
- val driverSparkConfWithCredentialsLocations = setDriverPodKubernetesCredentialLocations(
- driverSparkConf,
- oauthTokenBase64,
- caCertDataBase64,
- clientKeyDataBase64,
- clientCertDataBase64)
-
- val kubernetesCredentialsSecret = createCredentialsSecret(
- oauthTokenBase64,
- caCertDataBase64,
- clientKeyDataBase64,
- clientCertDataBase64)
-
- val driverPodWithMountedKubernetesCredentials = kubernetesCredentialsSecret.map { secret =>
- new PodBuilder(driverSpec.driverPod)
- .editOrNewSpec()
- .addNewVolume()
- .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME)
- .withNewSecret().withSecretName(secret.getMetadata.getName).endSecret()
- .endVolume()
- .endSpec()
- .build()
- }.getOrElse(
- driverServiceAccount.map { account =>
- new PodBuilder(driverSpec.driverPod)
- .editOrNewSpec()
- .withServiceAccount(account)
- .withServiceAccountName(account)
- .endSpec()
- .build()
- }.getOrElse(driverSpec.driverPod)
- )
-
- val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { secret =>
- new ContainerBuilder(driverSpec.driverContainer)
- .addNewVolumeMount()
- .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME)
- .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR)
- .endVolumeMount()
- .build()
- }.getOrElse(driverSpec.driverContainer)
-
- driverSpec.copy(
- driverPod = driverPodWithMountedKubernetesCredentials,
- otherKubernetesResources =
- driverSpec.otherKubernetesResources ++ kubernetesCredentialsSecret.toSeq,
- driverSparkConf = driverSparkConfWithCredentialsLocations,
- driverContainer = driverContainerWithMountedSecretVolume)
- }
-
- private def createCredentialsSecret(
- driverOAuthTokenBase64: Option[String],
- driverCaCertDataBase64: Option[String],
- driverClientKeyDataBase64: Option[String],
- driverClientCertDataBase64: Option[String]): Option[Secret] = {
- val allSecretData =
- resolveSecretData(
- driverClientKeyDataBase64,
- DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++
- resolveSecretData(
- driverClientCertDataBase64,
- DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++
- resolveSecretData(
- driverCaCertDataBase64,
- DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++
- resolveSecretData(
- driverOAuthTokenBase64,
- DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME)
-
- if (allSecretData.isEmpty) {
- None
- } else {
- Some(new SecretBuilder()
- .withNewMetadata()
- .withName(s"$kubernetesResourceNamePrefix-kubernetes-credentials")
- .endMetadata()
- .withData(allSecretData.asJava)
- .build())
- }
- }
-
- private def setDriverPodKubernetesCredentialLocations(
- driverSparkConf: SparkConf,
- driverOauthTokenBase64: Option[String],
- driverCaCertDataBase64: Option[String],
- driverClientKeyDataBase64: Option[String],
- driverClientCertDataBase64: Option[String]): SparkConf = {
- val resolvedMountedOAuthTokenFile = resolveSecretLocation(
- maybeMountedOAuthTokenFile,
- driverOauthTokenBase64,
- DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH)
- val resolvedMountedClientKeyFile = resolveSecretLocation(
- maybeMountedClientKeyFile,
- driverClientKeyDataBase64,
- DRIVER_CREDENTIALS_CLIENT_KEY_PATH)
- val resolvedMountedClientCertFile = resolveSecretLocation(
- maybeMountedClientCertFile,
- driverClientCertDataBase64,
- DRIVER_CREDENTIALS_CLIENT_CERT_PATH)
- val resolvedMountedCaCertFile = resolveSecretLocation(
- maybeMountedCaCertFile,
- driverCaCertDataBase64,
- DRIVER_CREDENTIALS_CA_CERT_PATH)
-
- val sparkConfWithCredentialLocations = driverSparkConf
- .setOption(
- s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX",
- resolvedMountedCaCertFile)
- .setOption(
- s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX",
- resolvedMountedClientKeyFile)
- .setOption(
- s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX",
- resolvedMountedClientCertFile)
- .setOption(
- s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX",
- resolvedMountedOAuthTokenFile)
-
- // Redact all OAuth token values
- sparkConfWithCredentialLocations
- .getAll
- .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)).map(_._1)
- .foreach {
- sparkConfWithCredentialLocations.set(_, "")
- }
- sparkConfWithCredentialLocations
- }
-
- private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = {
- submissionSparkConf.getOption(conf)
- .map(new File(_))
- .map { file =>
- require(file.isFile, String.format("%s provided at %s does not exist or is not a file.",
- fileType, file.getAbsolutePath))
- BaseEncoding.base64().encode(Files.toByteArray(file))
- }
- }
-
- private def resolveSecretLocation(
- mountedUserSpecified: Option[String],
- valueMountedFromSubmitter: Option[String],
- mountedCanonicalLocation: String): Option[String] = {
- mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ =>
- mountedCanonicalLocation
- })
- }
-
- /**
- * Resolve a Kubernetes secret data entry from an optional client credential used by the
- * driver to talk to the Kubernetes API server.
- *
- * @param userSpecifiedCredential the optional user-specified client credential.
- * @param secretName name of the Kubernetes secret storing the client credential.
- * @return a secret data entry in the form of a map from the secret name to the secret data,
- * which may be empty if the user-specified credential is empty.
- */
- private def resolveSecretData(
- userSpecifiedCredential: Option[String],
- secretName: String): Map[String, String] = {
- userSpecifiedCredential.map { valueBase64 =>
- Map(secretName -> valueBase64)
- }.getOrElse(Map.empty[String, String])
- }
-
- private implicit def augmentSparkConf(sparkConf: SparkConf): OptionSettableSparkConf = {
- new OptionSettableSparkConf(sparkConf)
- }
-}
-
-private class OptionSettableSparkConf(sparkConf: SparkConf) {
- def setOption(configEntry: String, option: Option[String]): SparkConf = {
- option.foreach { opt =>
- sparkConf.set(configEntry, opt)
- }
- sparkConf
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala
deleted file mode 100644
index 34af7cde6c1a9..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala
+++ /dev/null
@@ -1,104 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps
-
-import scala.collection.JavaConverters._
-
-import io.fabric8.kubernetes.api.model.ServiceBuilder
-
-import org.apache.spark.SparkConf
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
-import org.apache.spark.internal.Logging
-import org.apache.spark.util.Clock
-
-/**
- * Allows the driver to be reachable by executor pods through a headless service. The service's
- * ports should correspond to the ports that the executor will reach the pod at for RPC.
- */
-private[spark] class DriverServiceBootstrapStep(
- resourceNamePrefix: String,
- driverLabels: Map[String, String],
- sparkConf: SparkConf,
- clock: Clock) extends DriverConfigurationStep with Logging {
-
- import DriverServiceBootstrapStep._
-
- override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
- require(sparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty,
- s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " +
- "address is managed and set to the driver pod's IP address.")
- require(sparkConf.getOption(DRIVER_HOST_KEY).isEmpty,
- s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " +
- "managed via a Kubernetes service.")
-
- val preferredServiceName = s"$resourceNamePrefix$DRIVER_SVC_POSTFIX"
- val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) {
- preferredServiceName
- } else {
- val randomServiceId = clock.getTimeMillis()
- val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX"
- logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " +
- s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " +
- s"$shorterServiceName as the driver service's name.")
- shorterServiceName
- }
-
- val driverPort = sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT)
- val driverBlockManagerPort = sparkConf.getInt(
- org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT)
- val driverService = new ServiceBuilder()
- .withNewMetadata()
- .withName(resolvedServiceName)
- .endMetadata()
- .withNewSpec()
- .withClusterIP("None")
- .withSelector(driverLabels.asJava)
- .addNewPort()
- .withName(DRIVER_PORT_NAME)
- .withPort(driverPort)
- .withNewTargetPort(driverPort)
- .endPort()
- .addNewPort()
- .withName(BLOCK_MANAGER_PORT_NAME)
- .withPort(driverBlockManagerPort)
- .withNewTargetPort(driverBlockManagerPort)
- .endPort()
- .endSpec()
- .build()
-
- val namespace = sparkConf.get(KUBERNETES_NAMESPACE)
- val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc"
- val resolvedSparkConf = driverSpec.driverSparkConf.clone()
- .set(DRIVER_HOST_KEY, driverHostname)
- .set("spark.driver.port", driverPort.toString)
- .set(
- org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, driverBlockManagerPort)
-
- driverSpec.copy(
- driverSparkConf = resolvedSparkConf,
- otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(driverService))
- }
-}
-
-private[spark] object DriverServiceBootstrapStep {
- val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key
- val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key
- val DRIVER_SVC_POSTFIX = "-driver-svc"
- val MAX_SERVICE_NAME_LENGTH = 63
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala
deleted file mode 100644
index 01469853dacc2..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps.initcontainer
-
-import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.KubernetesUtils
-
-/**
- * Performs basic configuration for the driver init-container with most of the work delegated to
- * the given InitContainerBootstrap.
- */
-private[spark] class BasicInitContainerConfigurationStep(
- sparkJars: Seq[String],
- sparkFiles: Seq[String],
- jarsDownloadPath: String,
- filesDownloadPath: String,
- bootstrap: InitContainerBootstrap)
- extends InitContainerConfigurationStep {
-
- override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = {
- val remoteJarsToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkJars)
- val remoteFilesToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkFiles)
- val remoteJarsConf = if (remoteJarsToDownload.nonEmpty) {
- Map(INIT_CONTAINER_REMOTE_JARS.key -> remoteJarsToDownload.mkString(","))
- } else {
- Map()
- }
- val remoteFilesConf = if (remoteFilesToDownload.nonEmpty) {
- Map(INIT_CONTAINER_REMOTE_FILES.key -> remoteFilesToDownload.mkString(","))
- } else {
- Map()
- }
-
- val baseInitContainerConfig = Map(
- JARS_DOWNLOAD_LOCATION.key -> jarsDownloadPath,
- FILES_DOWNLOAD_LOCATION.key -> filesDownloadPath) ++
- remoteJarsConf ++
- remoteFilesConf
-
- val bootstrapped = bootstrap.bootstrapInitContainer(
- PodWithDetachedInitContainer(
- spec.driverPod,
- spec.initContainer,
- spec.driverContainer))
-
- spec.copy(
- initContainer = bootstrapped.initContainer,
- driverContainer = bootstrapped.mainContainer,
- driverPod = bootstrapped.pod,
- properties = spec.properties ++ baseInitContainerConfig)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala
deleted file mode 100644
index f2c29c7ce1076..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps.initcontainer
-
-import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-
-/**
- * Figures out and returns the complete ordered list of InitContainerConfigurationSteps required to
- * configure the driver init-container. The returned steps will be applied in the given order to
- * produce a final InitContainerSpec that is used to construct the driver init-container in
- * DriverInitContainerBootstrapStep. This class is only used when an init-container is needed, i.e.,
- * when there are remote application dependencies to localize.
- */
-private[spark] class InitContainerConfigOrchestrator(
- sparkJars: Seq[String],
- sparkFiles: Seq[String],
- jarsDownloadPath: String,
- filesDownloadPath: String,
- imagePullPolicy: String,
- configMapName: String,
- configMapKey: String,
- sparkConf: SparkConf) {
-
- private val initContainerImage = sparkConf
- .get(INIT_CONTAINER_IMAGE)
- .getOrElse(throw new SparkException(
- "Must specify the init-container image when there are remote dependencies"))
-
- def getAllConfigurationSteps: Seq[InitContainerConfigurationStep] = {
- val initContainerBootstrap = new InitContainerBootstrap(
- initContainerImage,
- imagePullPolicy,
- jarsDownloadPath,
- filesDownloadPath,
- configMapName,
- configMapKey,
- SPARK_POD_DRIVER_ROLE,
- sparkConf)
- val baseStep = new BasicInitContainerConfigurationStep(
- sparkJars,
- sparkFiles,
- jarsDownloadPath,
- filesDownloadPath,
- initContainerBootstrap)
-
- val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs(
- sparkConf,
- KUBERNETES_DRIVER_SECRETS_PREFIX)
- // Mount user-specified driver secrets also into the driver's init-container. The
- // init-container may need credentials in the secrets to be able to download remote
- // dependencies. The driver's main container and its init-container share the secrets
- // because the init-container is sort of an implementation details and this sharing
- // avoids introducing a dedicated configuration property just for the init-container.
- val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) {
- Seq(new InitContainerMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths)))
- } else {
- Nil
- }
-
- Seq(baseStep) ++ mountSecretsStep
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala
deleted file mode 100644
index b52c343f0c0ed..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps.initcontainer
-
-import io.fabric8.kubernetes.api.model.{Container, HasMetadata, Pod}
-
-/**
- * Represents a specification of the init-container for the driver pod.
- *
- * @param properties properties that should be set on the init-container
- * @param driverSparkConf Spark configuration properties that will be carried back to the driver
- * @param initContainer the init-container object
- * @param driverContainer the driver container object
- * @param driverPod the driver pod object
- * @param dependentResources resources the init-container depends on to work
- */
-private[spark] case class InitContainerSpec(
- properties: Map[String, String],
- driverSparkConf: Map[String, String],
- initContainer: Container,
- driverContainer: Container,
- driverPod: Pod,
- dependentResources: Seq[HasMetadata])
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala
deleted file mode 100644
index 141bd2827e7c5..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala
+++ /dev/null
@@ -1,251 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.scheduler.cluster.k8s
-
-import scala.collection.JavaConverters._
-
-import io.fabric8.kubernetes.api.model._
-
-import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, PodWithDetachedInitContainer}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD}
-import org.apache.spark.util.Utils
-
-/**
- * A factory class for bootstrapping and creating executor pods with the given bootstrapping
- * components.
- *
- * @param sparkConf Spark configuration
- * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto
- * user-specified paths into the executor container
- * @param initContainerBootstrap an optional component for bootstrapping the executor init-container
- * if one is needed, i.e., when there are remote dependencies to
- * localize
- * @param initContainerMountSecretsBootstrap an optional component for mounting user-specified
- * secrets onto user-specified paths into the executor
- * init-container
- */
-private[spark] class ExecutorPodFactory(
- sparkConf: SparkConf,
- mountSecretsBootstrap: Option[MountSecretsBootstrap],
- initContainerBootstrap: Option[InitContainerBootstrap],
- initContainerMountSecretsBootstrap: Option[MountSecretsBootstrap]) {
-
- private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH)
-
- private val executorLabels = KubernetesUtils.parsePrefixedKeyValuePairs(
- sparkConf,
- KUBERNETES_EXECUTOR_LABEL_PREFIX)
- require(
- !executorLabels.contains(SPARK_APP_ID_LABEL),
- s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.")
- require(
- !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL),
- s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" +
- " Spark.")
- require(
- !executorLabels.contains(SPARK_ROLE_LABEL),
- s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.")
-
- private val executorAnnotations =
- KubernetesUtils.parsePrefixedKeyValuePairs(
- sparkConf,
- KUBERNETES_EXECUTOR_ANNOTATION_PREFIX)
- private val nodeSelector =
- KubernetesUtils.parsePrefixedKeyValuePairs(
- sparkConf,
- KUBERNETES_NODE_SELECTOR_PREFIX)
-
- private val executorContainerImage = sparkConf
- .get(EXECUTOR_CONTAINER_IMAGE)
- .getOrElse(throw new SparkException("Must specify the executor container image"))
- private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY)
- private val blockManagerPort = sparkConf
- .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT)
-
- private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX)
-
- private val executorMemoryMiB = sparkConf.get(EXECUTOR_MEMORY)
- private val executorMemoryString = sparkConf.get(
- EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString)
-
- private val memoryOverheadMiB = sparkConf
- .get(EXECUTOR_MEMORY_OVERHEAD)
- .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt,
- MEMORY_OVERHEAD_MIN_MIB))
- private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB
-
- private val executorCores = sparkConf.getDouble("spark.executor.cores", 1)
- private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES)
-
- private val executorJarsDownloadDir = sparkConf.get(JARS_DOWNLOAD_LOCATION)
-
- /**
- * Configure and construct an executor pod with the given parameters.
- */
- def createExecutorPod(
- executorId: String,
- applicationId: String,
- driverUrl: String,
- executorEnvs: Seq[(String, String)],
- driverPod: Pod,
- nodeToLocalTaskCount: Map[String, Int]): Pod = {
- val name = s"$executorPodNamePrefix-exec-$executorId"
-
- // hostname must be no longer than 63 characters, so take the last 63 characters of the pod
- // name as the hostname. This preserves uniqueness since the end of name contains
- // executorId
- val hostname = name.substring(Math.max(0, name.length - 63))
- val resolvedExecutorLabels = Map(
- SPARK_EXECUTOR_ID_LABEL -> executorId,
- SPARK_APP_ID_LABEL -> applicationId,
- SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++
- executorLabels
- val executorMemoryQuantity = new QuantityBuilder(false)
- .withAmount(s"${executorMemoryMiB}Mi")
- .build()
- val executorMemoryLimitQuantity = new QuantityBuilder(false)
- .withAmount(s"${executorMemoryWithOverhead}Mi")
- .build()
- val executorCpuQuantity = new QuantityBuilder(false)
- .withAmount(executorCores.toString)
- .build()
- val executorExtraClasspathEnv = executorExtraClasspath.map { cp =>
- new EnvVarBuilder()
- .withName(ENV_CLASSPATH)
- .withValue(cp)
- .build()
- }
- val executorExtraJavaOptionsEnv = sparkConf
- .get(EXECUTOR_JAVA_OPTIONS)
- .map { opts =>
- val delimitedOpts = Utils.splitCommandString(opts)
- delimitedOpts.zipWithIndex.map {
- case (opt, index) =>
- new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build()
- }
- }.getOrElse(Seq.empty[EnvVar])
- val executorEnv = (Seq(
- (ENV_DRIVER_URL, driverUrl),
- // Executor backend expects integral value for executor cores, so round it up to an int.
- (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString),
- (ENV_EXECUTOR_MEMORY, executorMemoryString),
- (ENV_APPLICATION_ID, applicationId),
- (ENV_EXECUTOR_ID, executorId),
- (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs)
- .map(env => new EnvVarBuilder()
- .withName(env._1)
- .withValue(env._2)
- .build()
- ) ++ Seq(
- new EnvVarBuilder()
- .withName(ENV_EXECUTOR_POD_IP)
- .withValueFrom(new EnvVarSourceBuilder()
- .withNewFieldRef("v1", "status.podIP")
- .build())
- .build()
- ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq
- val requiredPorts = Seq(
- (BLOCK_MANAGER_PORT_NAME, blockManagerPort))
- .map { case (name, port) =>
- new ContainerPortBuilder()
- .withName(name)
- .withContainerPort(port)
- .build()
- }
-
- val executorContainer = new ContainerBuilder()
- .withName("executor")
- .withImage(executorContainerImage)
- .withImagePullPolicy(imagePullPolicy)
- .withNewResources()
- .addToRequests("memory", executorMemoryQuantity)
- .addToLimits("memory", executorMemoryLimitQuantity)
- .addToRequests("cpu", executorCpuQuantity)
- .endResources()
- .addAllToEnv(executorEnv.asJava)
- .withPorts(requiredPorts.asJava)
- .addToArgs("executor")
- .build()
-
- val executorPod = new PodBuilder()
- .withNewMetadata()
- .withName(name)
- .withLabels(resolvedExecutorLabels.asJava)
- .withAnnotations(executorAnnotations.asJava)
- .withOwnerReferences()
- .addNewOwnerReference()
- .withController(true)
- .withApiVersion(driverPod.getApiVersion)
- .withKind(driverPod.getKind)
- .withName(driverPod.getMetadata.getName)
- .withUid(driverPod.getMetadata.getUid)
- .endOwnerReference()
- .endMetadata()
- .withNewSpec()
- .withHostname(hostname)
- .withRestartPolicy("Never")
- .withNodeSelector(nodeSelector.asJava)
- .endSpec()
- .build()
-
- val containerWithLimitCores = executorLimitCores.map { limitCores =>
- val executorCpuLimitQuantity = new QuantityBuilder(false)
- .withAmount(limitCores)
- .build()
- new ContainerBuilder(executorContainer)
- .editResources()
- .addToLimits("cpu", executorCpuLimitQuantity)
- .endResources()
- .build()
- }.getOrElse(executorContainer)
-
- val (maybeSecretsMountedPod, maybeSecretsMountedContainer) =
- mountSecretsBootstrap.map { bootstrap =>
- (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores))
- }.getOrElse((executorPod, containerWithLimitCores))
-
- val (bootstrappedPod, bootstrappedContainer) =
- initContainerBootstrap.map { bootstrap =>
- val podWithInitContainer = bootstrap.bootstrapInitContainer(
- PodWithDetachedInitContainer(
- maybeSecretsMountedPod,
- new ContainerBuilder().build(),
- maybeSecretsMountedContainer))
-
- val (pod, mayBeSecretsMountedInitContainer) =
- initContainerMountSecretsBootstrap.map { bootstrap =>
- // Mount the secret volumes given that the volumes have already been added to the
- // executor pod when mounting the secrets into the main executor container.
- (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer))
- }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer))
-
- val bootstrappedPod = KubernetesUtils.appendInitContainer(
- pod, mayBeSecretsMountedInitContainer)
-
- (bootstrappedPod, podWithInitContainer.mainContainer)
- }.getOrElse((maybeSecretsMountedPod, maybeSecretsMountedContainer))
-
- new PodBuilder(bootstrappedPod)
- .editSpec()
- .addToContainers(bootstrappedContainer)
- .endSpec()
- .build()
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala
new file mode 100644
index 0000000000000..83daddf714489
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import io.fabric8.kubernetes.api.model.Pod
+
+sealed trait ExecutorPodState {
+ def pod: Pod
+}
+
+case class PodRunning(pod: Pod) extends ExecutorPodState
+
+case class PodPending(pod: Pod) extends ExecutorPodState
+
+sealed trait FinalPodState extends ExecutorPodState
+
+case class PodSucceeded(pod: Pod) extends FinalPodState
+
+case class PodFailed(pod: Pod) extends FinalPodState
+
+case class PodDeleted(pod: Pod) extends FinalPodState
+
+case class PodUnknown(pod: Pod) extends ExecutorPodState
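
Because the hierarchy is sealed, downstream consumers (the allocator and lifecycle manager below) can pattern match on it and the compiler will flag any unhandled state. A purely illustrative sketch, not part of this patch, assuming it sits in the same org.apache.spark.scheduler.cluster.k8s package:

def describe(state: ExecutorPodState): String = state match {
  case PodRunning(pod) => s"running on node ${pod.getSpec.getNodeName}"
  case PodPending(_) => "accepted by the API server but not yet running"
  case terminal: FinalPodState => s"reached final state ${terminal.getClass.getSimpleName}"
  case PodUnknown(_) => "reported an unrecognized phase"
}
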
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala
new file mode 100644
index 0000000000000..5a143ad3600fd
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
+
+import io.fabric8.kubernetes.api.model.PodBuilder
+import io.fabric8.kubernetes.client.KubernetesClient
+import scala.collection.mutable
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.KubernetesConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.{Clock, Utils}
+
+private[spark] class ExecutorPodsAllocator(
+ conf: SparkConf,
+ executorBuilder: KubernetesExecutorBuilder,
+ kubernetesClient: KubernetesClient,
+ snapshotsStore: ExecutorPodsSnapshotsStore,
+ clock: Clock) extends Logging {
+
+ private val EXECUTOR_ID_COUNTER = new AtomicLong(0L)
+
+ private val totalExpectedExecutors = new AtomicInteger(0)
+
+ private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE)
+
+ private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY)
+
+ private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000)
+
+ private val kubernetesDriverPodName = conf
+ .get(KUBERNETES_DRIVER_POD_NAME)
+ .getOrElse(throw new SparkException("Must specify the driver pod name"))
+
+ private val driverPod = kubernetesClient.pods()
+ .withName(kubernetesDriverPodName)
+ .get()
+
+ // Executor IDs that have been requested from Kubernetes but have not been detected in any
+ // snapshot yet. Mapped to the timestamp when they were created.
+ private val newlyCreatedExecutors = mutable.Map.empty[Long, Long]
+
+ def start(applicationId: String): Unit = {
+ snapshotsStore.addSubscriber(podAllocationDelay) {
+ onNewSnapshots(applicationId, _)
+ }
+ }
+
+ def setTotalExpectedExecutors(total: Int): Unit = totalExpectedExecutors.set(total)
+
+ private def onNewSnapshots(applicationId: String, snapshots: Seq[ExecutorPodsSnapshot]): Unit = {
+ newlyCreatedExecutors --= snapshots.flatMap(_.executorPods.keys)
+ // For all executors we've created against the API but have not seen in a snapshot
+ // yet - check the current time. If the current time has exceeded some threshold,
+ // assume that the pod was either never created (the API server never properly
+ // handled the creation request), or the API server created the pod but we missed
+ // both the creation and deletion events. In either case, delete the missing pod
+ // if possible, and mark such a pod to be rescheduled below.
+ newlyCreatedExecutors.foreach { case (execId, timeCreated) =>
+ val currentTime = clock.getTimeMillis()
+ if (currentTime - timeCreated > podCreationTimeout) {
+ logWarning(s"Executor with id $execId was not detected in the Kubernetes" +
+ s" cluster after $podCreationTimeout milliseconds despite the fact that a" +
+ " previous allocation attempt tried to create it. The executor may have been" +
+ " deleted but the application missed the deletion event.")
+ Utils.tryLogNonFatalError {
+ kubernetesClient
+ .pods()
+ .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString)
+ .delete()
+ }
+ newlyCreatedExecutors -= execId
+ } else {
+ logDebug(s"Executor with id $execId was not found in the Kubernetes cluster since it" +
+ s" was created ${currentTime - timeCreated} milliseconds ago.")
+ }
+ }
+
+ if (snapshots.nonEmpty) {
+ // Only need to examine the cluster as of the latest snapshot, the "current" state, to see if
+ // we need to allocate more executors or not.
+ val latestSnapshot = snapshots.last
+ val currentRunningExecutors = latestSnapshot.executorPods.values.count {
+ case PodRunning(_) => true
+ case _ => false
+ }
+ val currentPendingExecutors = latestSnapshot.executorPods.values.count {
+ case PodPending(_) => true
+ case _ => false
+ }
+ val currentTotalExpectedExecutors = totalExpectedExecutors.get
+ logDebug(s"Currently have $currentRunningExecutors running executors and" +
+ s" $currentPendingExecutors pending executors. $newlyCreatedExecutors executors" +
+ s" have been requested but are pending appearance in the cluster.")
+ if (newlyCreatedExecutors.isEmpty
+ && currentPendingExecutors == 0
+ && currentRunningExecutors < currentTotalExpectedExecutors) {
+ val numExecutorsToAllocate = math.min(
+ currentTotalExpectedExecutors - currentRunningExecutors, podAllocationSize)
+ logInfo(s"Going to request $numExecutorsToAllocate executors from Kubernetes.")
+ for (_ <- 0 until numExecutorsToAllocate) {
+ val newExecutorId = EXECUTOR_ID_COUNTER.incrementAndGet()
+ val executorConf = KubernetesConf.createExecutorConf(
+ conf,
+ newExecutorId.toString,
+ applicationId,
+ driverPod)
+ val executorPod = executorBuilder.buildFromFeatures(executorConf)
+ val podWithAttachedContainer = new PodBuilder(executorPod.pod)
+ .editOrNewSpec()
+ .addToContainers(executorPod.container)
+ .endSpec()
+ .build()
+ kubernetesClient.pods().create(podWithAttachedContainer)
+ newlyCreatedExecutors(newExecutorId) = clock.getTimeMillis()
+ logDebug(s"Requested executor with id $newExecutorId from Kubernetes.")
+ }
+ } else if (currentRunningExecutors >= currentTotalExpectedExecutors) {
+ // TODO handle edge cases if we end up with more running executors than expected.
+ logDebug("Current number of running executors is equal to the number of requested" +
+ " executors. Not scaling up further.")
+ } else if (newlyCreatedExecutors.nonEmpty || currentPendingExecutors != 0) {
+ logDebug(s"Still waiting for ${newlyCreatedExecutors.size + currentPendingExecutors}" +
+ s" executors to begin running before requesting for more executors. # of executors in" +
+ s" pending status in the cluster: $currentPendingExecutors. # of executors that we have" +
+ s" created but we have not observed as being present in the cluster yet:" +
+ s" ${newlyCreatedExecutors.size}.")
+ }
+ }
+ }
+}
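
The allocator only requests a new batch when nothing is outstanding. A hypothetical pure extraction of that rule (the helper name and shape are illustrative, not part of this patch) makes the batching behaviour easy to see:

// Request a new batch only when no previously requested pods are still unaccounted for
// (neither newly created nor pending) and we are below the target; the batch is capped
// by spark.kubernetes.allocation.batch.size.
def podsToRequest(
    running: Int,
    pending: Int,
    newlyCreated: Int,
    expected: Int,
    batchSize: Int): Int = {
  if (newlyCreated == 0 && pending == 0 && running < expected) {
    math.min(expected - running, batchSize)
  } else {
    0
  }
}

assert(podsToRequest(running = 2, pending = 0, newlyCreated = 0, expected = 10, batchSize = 5) == 5)
assert(podsToRequest(running = 2, pending = 1, newlyCreated = 0, expected = 10, batchSize = 5) == 0)
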
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala
new file mode 100644
index 0000000000000..b28d93990313e
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import com.google.common.cache.Cache
+import io.fabric8.kubernetes.api.model.Pod
+import io.fabric8.kubernetes.client.KubernetesClient
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler.ExecutorExited
+import org.apache.spark.util.Utils
+
+private[spark] class ExecutorPodsLifecycleManager(
+ conf: SparkConf,
+ executorBuilder: KubernetesExecutorBuilder,
+ kubernetesClient: KubernetesClient,
+ snapshotsStore: ExecutorPodsSnapshotsStore,
+ // Track, on a best-effort basis, which executors have already been removed. Removing an
+ // executor more than once is not generally job-breaking, but it is worth trying to avoid.
+ // Cache entries expire so that this data structure does not grow without bound.
+ removedExecutorsCache: Cache[java.lang.Long, java.lang.Long]) extends Logging {
+
+ import ExecutorPodsLifecycleManager._
+
+ private val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL)
+
+ def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = {
+ snapshotsStore.addSubscriber(eventProcessingInterval) {
+ onNewSnapshots(schedulerBackend, _)
+ }
+ }
+
+ private def onNewSnapshots(
+ schedulerBackend: KubernetesClusterSchedulerBackend,
+ snapshots: Seq[ExecutorPodsSnapshot]): Unit = {
+ val execIdsRemovedInThisRound = mutable.HashSet.empty[Long]
+ snapshots.foreach { snapshot =>
+ snapshot.executorPods.foreach { case (execId, state) =>
+ state match {
+ case deleted@PodDeleted(_) =>
+ logDebug(s"Snapshot reported deleted executor with id $execId," +
+ s" pod name ${state.pod.getMetadata.getName}")
+ removeExecutorFromSpark(schedulerBackend, deleted, execId)
+ execIdsRemovedInThisRound += execId
+ case failed@PodFailed(_) =>
+ logDebug(s"Snapshot reported failed executor with id $execId," +
+ s" pod name ${state.pod.getMetadata.getName}")
+ onFinalNonDeletedState(failed, execId, schedulerBackend, execIdsRemovedInThisRound)
+ case succeeded@PodSucceeded(_) =>
+ logDebug(s"Snapshot reported succeeded executor with id $execId," +
+ s" pod name ${state.pod.getMetadata.getName}. Note that succeeded executors are" +
+ s" unusual unless Spark specifically informed the executor to exit.")
+ onFinalNonDeletedState(succeeded, execId, schedulerBackend, execIdsRemovedInThisRound)
+ case _ =>
+ }
+ }
+ }
+
+ // Reconcile the case where Spark claims to know about an executor but the corresponding pod
+ // is missing from the cluster. This would occur if we miss a deletion event and the pod
+ // transitions immediately from running to absent. We only need to check against the latest
+ // snapshot for this, and we don't do this for executors in the deleted executors cache or
+ // that we just removed in this round.
+ if (snapshots.nonEmpty) {
+ val latestSnapshot = snapshots.last
+ (schedulerBackend.getExecutorIds().map(_.toLong).toSet
+ -- latestSnapshot.executorPods.keySet
+ -- execIdsRemovedInThisRound).foreach { missingExecutorId =>
+ if (removedExecutorsCache.getIfPresent(missingExecutorId) == null) {
+ val exitReasonMessage = s"The executor with ID $missingExecutorId was not found in the" +
+ s" cluster but we didn't get a reason why. Marking the executor as failed. The" +
+ s" executor may have been deleted but the driver missed the deletion event."
+ logDebug(exitReasonMessage)
+ val exitReason = ExecutorExited(
+ UNKNOWN_EXIT_CODE,
+ exitCausedByApp = false,
+ exitReasonMessage)
+ schedulerBackend.doRemoveExecutor(missingExecutorId.toString, exitReason)
+ execIdsRemovedInThisRound += missingExecutorId
+ }
+ }
+ }
+ logDebug(s"Removed executors with ids ${execIdsRemovedInThisRound.mkString(",")}" +
+ s" from Spark that were either found to be deleted or non-existent in the cluster.")
+ }
+
+ private def onFinalNonDeletedState(
+ podState: FinalPodState,
+ execId: Long,
+ schedulerBackend: KubernetesClusterSchedulerBackend,
+ execIdsRemovedInRound: mutable.Set[Long]): Unit = {
+ removeExecutorFromK8s(podState.pod)
+ removeExecutorFromSpark(schedulerBackend, podState, execId)
+ execIdsRemovedInRound += execId
+ }
+
+ private def removeExecutorFromK8s(updatedPod: Pod): Unit = {
+ // If deletion failed on a previous try, we can try again if resync informs us the pod
+ // is still around.
+ // Delete as best attempt - duplicate deletes will throw an exception but the end state
+ // of getting rid of the pod is what matters.
+ Utils.tryLogNonFatalError {
+ kubernetesClient
+ .pods()
+ .withName(updatedPod.getMetadata.getName)
+ .delete()
+ }
+ }
+
+ private def removeExecutorFromSpark(
+ schedulerBackend: KubernetesClusterSchedulerBackend,
+ podState: FinalPodState,
+ execId: Long): Unit = {
+ if (removedExecutorsCache.getIfPresent(execId) == null) {
+ removedExecutorsCache.put(execId, execId)
+ val exitReason = findExitReason(podState, execId)
+ schedulerBackend.doRemoveExecutor(execId.toString, exitReason)
+ }
+ }
+
+ private def findExitReason(podState: FinalPodState, execId: Long): ExecutorExited = {
+ val exitCode = findExitCode(podState)
+ val (exitCausedByApp, exitMessage) = podState match {
+ case PodDeleted(_) =>
+ (false, s"The executor with id $execId was deleted by a user or the framework.")
+ case _ =>
+ val msg = exitReasonMessage(podState, execId, exitCode)
+ (true, msg)
+ }
+ ExecutorExited(exitCode, exitCausedByApp, exitMessage)
+ }
+
+ private def exitReasonMessage(podState: FinalPodState, execId: Long, exitCode: Int) = {
+ val pod = podState.pod
+ s"""
+ |The executor with id $execId exited with exit code $exitCode.
+ |The API gave the following brief reason: ${pod.getStatus.getReason}
+ |The API gave the following message: ${pod.getStatus.getMessage}
+ |The API gave the following container statuses:
+ |
+ |${pod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")}
+ """.stripMargin
+ }
+
+ private def findExitCode(podState: FinalPodState): Int = {
+ podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus =>
+ containerStatus.getState.getTerminated != null
+ }.map { terminatedContainer =>
+ terminatedContainer.getState.getTerminated.getExitCode.toInt
+ }.getOrElse(UNKNOWN_EXIT_CODE)
+ }
+}
+
+private object ExecutorPodsLifecycleManager {
+ val UNKNOWN_EXIT_CODE = -1
+}
+
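
A small sketch (hypothetical helper, not part of the patch) of the set arithmetic in the reconciliation step above: executor IDs Spark still considers alive that are absent from the latest snapshot, and were neither removed in this round nor already recorded in the removed-executors cache, are reported as lost with an unknown exit code.

def missingFromCluster(
    knownToSpark: Set[Long],
    inLatestSnapshot: Set[Long],
    alreadyHandled: Set[Long]): Set[Long] =
  knownToSpark -- inLatestSnapshot -- alreadyHandled

// Executor 2 is known to Spark, has no pod in the snapshot, and has not been handled yet,
// so it would be removed with ExecutorExited(UNKNOWN_EXIT_CODE, exitCausedByApp = false, ...).
assert(missingFromCluster(Set(1L, 2L, 3L), Set(1L), Set(3L)) == Set(2L))
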
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala
new file mode 100644
index 0000000000000..e77e604d00e0f
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import java.util.concurrent.{Future, ScheduledExecutorService, TimeUnit}
+
+import io.fabric8.kubernetes.client.KubernetesClient
+import scala.collection.JavaConverters._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.ThreadUtils
+
+private[spark] class ExecutorPodsPollingSnapshotSource(
+ conf: SparkConf,
+ kubernetesClient: KubernetesClient,
+ snapshotsStore: ExecutorPodsSnapshotsStore,
+ pollingExecutor: ScheduledExecutorService) extends Logging {
+
+ private val pollingInterval = conf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL)
+
+ private var pollingFuture: Future[_] = _
+
+ def start(applicationId: String): Unit = {
+ require(pollingFuture == null, "Cannot start polling more than once.")
+ logDebug(s"Starting to check for executor pod state every $pollingInterval ms.")
+ pollingFuture = pollingExecutor.scheduleWithFixedDelay(
+ new PollRunnable(applicationId), pollingInterval, pollingInterval, TimeUnit.MILLISECONDS)
+ }
+
+ def stop(): Unit = {
+ if (pollingFuture != null) {
+ pollingFuture.cancel(true)
+ pollingFuture = null
+ }
+ ThreadUtils.shutdown(pollingExecutor)
+ }
+
+ private class PollRunnable(applicationId: String) extends Runnable {
+ override def run(): Unit = {
+ logDebug(s"Resynchronizing full executor pod state from Kubernetes.")
+ snapshotsStore.replaceSnapshot(kubernetesClient
+ .pods()
+ .withLabel(SPARK_APP_ID_LABEL, applicationId)
+ .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)
+ .list()
+ .getItems
+ .asScala)
+ }
+ }
+
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala
new file mode 100644
index 0000000000000..26be918043412
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import io.fabric8.kubernetes.api.model.Pod
+
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.internal.Logging
+
+/**
+ * An immutable view of the current executor pods that are running in the cluster.
+ */
+private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) {
+
+ import ExecutorPodsSnapshot._
+
+ def withUpdate(updatedPod: Pod): ExecutorPodsSnapshot = {
+ val newExecutorPods = executorPods ++ toStatesByExecutorId(Seq(updatedPod))
+ new ExecutorPodsSnapshot(newExecutorPods)
+ }
+}
+
+object ExecutorPodsSnapshot extends Logging {
+
+ def apply(executorPods: Seq[Pod]): ExecutorPodsSnapshot = {
+ ExecutorPodsSnapshot(toStatesByExecutorId(executorPods))
+ }
+
+ def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState])
+
+ private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = {
+ executorPods.map { pod =>
+ (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod))
+ }.toMap
+ }
+
+ private def toState(pod: Pod): ExecutorPodState = {
+ if (isDeleted(pod)) {
+ PodDeleted(pod)
+ } else {
+ val phase = pod.getStatus.getPhase.toLowerCase
+ phase match {
+ case "pending" =>
+ PodPending(pod)
+ case "running" =>
+ PodRunning(pod)
+ case "failed" =>
+ PodFailed(pod)
+ case "succeeded" =>
+ PodSucceeded(pod)
+ case _ =>
+ logWarning(s"Received unknown phase $phase for executor pod with name" +
+ s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}")
+ PodUnknown(pod)
+ }
+ }
+ }
+
+ private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null
+}
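
As an illustration of the phase-to-state mapping and of withUpdate() (a sketch only, using the fabric8 model classes already relied on throughout this patch): a pod labeled with SPARK_EXECUTOR_ID_LABEL and a "Running" phase becomes a PodRunning entry keyed by that executor id, and an incremental update replaces just that entry.

import io.fabric8.kubernetes.api.model.PodBuilder
import org.apache.spark.deploy.k8s.Constants.SPARK_EXECUTOR_ID_LABEL

val runningPod = new PodBuilder()
  .withNewMetadata()
    .withName("spark-app-exec-7")
    .addToLabels(SPARK_EXECUTOR_ID_LABEL, "7")
    .endMetadata()
  .withNewStatus()
    .withPhase("Running")
    .endStatus()
  .build()

val snapshot = ExecutorPodsSnapshot(Seq(runningPod))
assert(snapshot.executorPods(7L).isInstanceOf[PodRunning])

// An incremental update produces a new immutable snapshot with only this entry changed.
val failedPod = new PodBuilder(runningPod).editStatus().withPhase("Failed").endStatus().build()
assert(snapshot.withUpdate(failedPod).executorPods(7L).isInstanceOf[PodFailed])
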
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala
similarity index 66%
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala
rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala
index 17614e040e587..dd264332cf9e8 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala
@@ -14,17 +14,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.deploy.k8s.submit.steps
+package org.apache.spark.scheduler.cluster.k8s
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
+import io.fabric8.kubernetes.api.model.Pod
-/**
- * Represents a step in configuring the Spark driver pod.
- */
-private[spark] trait DriverConfigurationStep {
+private[spark] trait ExecutorPodsSnapshotsStore {
+
+ def addSubscriber
+ (processBatchIntervalMillis: Long)
+ (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit)
+
+ def stop(): Unit
+
+ def updatePod(updatedPod: Pod): Unit
- /**
- * Apply some transformation to the previous state of the driver to add a new feature to it.
- */
- def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec
+ def replaceSnapshot(newSnapshot: Seq[Pod]): Unit
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala
new file mode 100644
index 0000000000000..5583b4617eeb2
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import java.util.concurrent._
+
+import io.fabric8.kubernetes.api.model.Pod
+import javax.annotation.concurrent.GuardedBy
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * Controls the propagation of the Spark application's executor pods state to subscribers that
+ * react to that state.
+ *
+ * Roughly follows a producer-consumer model. Producers report states of executor pods, and these
+ * states are then published to consumers that can perform any actions in response to these states.
+ *
+ * Producers push updates in one of two ways. An incremental update sent by updatePod() represents
+ * a known new state of a single executor pod. A full sync sent by replaceSnapshot() indicates that
+ * the passed pods are all of the most up-to-date states of all executor pods for the application.
+ * The combination of the states of all executor pods for the application is collectively known as
+ * a snapshot. The store keeps track of the most up-to-date snapshot, and applies updates to that
+ * most recent snapshot, either by incrementally updating the snapshot with a single new pod state
+ * or by replacing the snapshot entirely on a full sync.
+ *
+ * Consumers, or subscribers, register that they want to be informed about all snapshots of the
+ * executor pods. Every time the store replaces its most up-to-date snapshot from either an
+ * incremental update or a full sync, the most recent snapshot after the update is posted to the
+ * subscriber's buffer. Subscribers receive blocks of snapshots produced by the producers in
+ * time-windowed chunks. Each subscriber can choose to receive their snapshot chunks at different
+ * time intervals.
+ */
+private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: ScheduledExecutorService)
+ extends ExecutorPodsSnapshotsStore {
+
+ private val SNAPSHOT_LOCK = new Object()
+
+ private val subscribers = mutable.Buffer.empty[SnapshotsSubscriber]
+ private val pollingTasks = mutable.Buffer.empty[Future[_]]
+
+ @GuardedBy("SNAPSHOT_LOCK")
+ private var currentSnapshot = ExecutorPodsSnapshot()
+
+ override def addSubscriber(
+ processBatchIntervalMillis: Long)
+ (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = {
+ val newSubscriber = SnapshotsSubscriber(
+ new LinkedBlockingQueue[ExecutorPodsSnapshot](), onNewSnapshots)
+ SNAPSHOT_LOCK.synchronized {
+ newSubscriber.snapshotsBuffer.add(currentSnapshot)
+ }
+ subscribers += newSubscriber
+ pollingTasks += subscribersExecutor.scheduleWithFixedDelay(
+ toRunnable(() => callSubscriber(newSubscriber)),
+ 0L,
+ processBatchIntervalMillis,
+ TimeUnit.MILLISECONDS)
+ }
+
+ override def stop(): Unit = {
+ pollingTasks.foreach(_.cancel(true))
+ ThreadUtils.shutdown(subscribersExecutor)
+ }
+
+ override def updatePod(updatedPod: Pod): Unit = SNAPSHOT_LOCK.synchronized {
+ currentSnapshot = currentSnapshot.withUpdate(updatedPod)
+ addCurrentSnapshotToSubscribers()
+ }
+
+ override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = SNAPSHOT_LOCK.synchronized {
+ currentSnapshot = ExecutorPodsSnapshot(newSnapshot)
+ addCurrentSnapshotToSubscribers()
+ }
+
+ private def addCurrentSnapshotToSubscribers(): Unit = {
+ subscribers.foreach { subscriber =>
+ subscriber.snapshotsBuffer.add(currentSnapshot)
+ }
+ }
+
+ private def callSubscriber(subscriber: SnapshotsSubscriber): Unit = {
+ Utils.tryLogNonFatalError {
+ val currentSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot].asJava
+ subscriber.snapshotsBuffer.drainTo(currentSnapshots)
+ subscriber.onNewSnapshots(currentSnapshots.asScala)
+ }
+ }
+
+ private def toRunnable[T](runnable: () => Unit): Runnable = new Runnable {
+ override def run(): Unit = runnable()
+ }
+
+ private case class SnapshotsSubscriber(
+ snapshotsBuffer: BlockingQueue[ExecutorPodsSnapshot],
+ onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit)
+}
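
A minimal usage sketch of the subscription contract described above (illustrative only; it reuses ThreadUtils.newDaemonThreadPoolScheduledExecutor the same way the cluster manager wiring later in this patch does): each subscriber drains, at its own interval, every snapshot published since its last callback.

import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.util.ThreadUtils

val subscribersExecutor =
  ThreadUtils.newDaemonThreadPoolScheduledExecutor("snapshots-store-example", 1)
val store = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor)
val snapshotsSeen = new AtomicInteger(0)

// The callback receives a time-windowed batch: all snapshots buffered since the last drain.
store.addSubscriber(processBatchIntervalMillis = 100) { snapshots =>
  snapshotsSeen.addAndGet(snapshots.size)
}

// Producers then publish either incremental updates or full resyncs:
//   store.updatePod(updatedPod)            // from the watch source
//   store.replaceSnapshot(allExecutorPods) // from the polling source
store.stop()
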
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala
new file mode 100644
index 0000000000000..a6749a644e00c
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import java.io.Closeable
+
+import io.fabric8.kubernetes.api.model.Pod
+import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher}
+import io.fabric8.kubernetes.client.Watcher.Action
+
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+private[spark] class ExecutorPodsWatchSnapshotSource(
+ snapshotsStore: ExecutorPodsSnapshotsStore,
+ kubernetesClient: KubernetesClient) extends Logging {
+
+ private var watchConnection: Closeable = _
+
+ def start(applicationId: String): Unit = {
+ require(watchConnection == null, "Cannot start the watcher twice.")
+ logDebug(s"Starting watch for pods with labels $SPARK_APP_ID_LABEL=$applicationId," +
+ s" $SPARK_ROLE_LABEL=$SPARK_POD_EXECUTOR_ROLE.")
+ watchConnection = kubernetesClient.pods()
+ .withLabel(SPARK_APP_ID_LABEL, applicationId)
+ .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)
+ .watch(new ExecutorPodsWatcher())
+ }
+
+ def stop(): Unit = {
+ if (watchConnection != null) {
+ Utils.tryLogNonFatalError {
+ watchConnection.close()
+ }
+ watchConnection = null
+ }
+ }
+
+ private class ExecutorPodsWatcher extends Watcher[Pod] {
+ override def eventReceived(action: Action, pod: Pod): Unit = {
+ val podName = pod.getMetadata.getName
+ logDebug(s"Received executor pod update for pod named $podName, action $action")
+ snapshotsStore.updatePod(pod)
+ }
+
+ override def onClose(e: KubernetesClientException): Unit = {
+ logWarning("Kubernetes client has been closed (this is expected if the application is" +
+ " shutting down.)", e)
+ }
+ }
+
+}
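
The watch source and the polling source above are complementary: the watcher feeds incremental updatePod() calls into the store, while the poller periodically replaces the whole snapshot and so recovers from missed watch events. A hypothetical wiring helper (the cluster-manager changes below do this inline) might look like the following sketch:

import java.util.concurrent.ScheduledExecutorService

import io.fabric8.kubernetes.client.KubernetesClient

import org.apache.spark.SparkConf

def startSnapshotSources(
    conf: SparkConf,
    kubernetesClient: KubernetesClient,
    snapshotsStore: ExecutorPodsSnapshotsStore,
    pollingExecutor: ScheduledExecutorService,
    applicationId: String): Unit = {
  val watchSource = new ExecutorPodsWatchSnapshotSource(snapshotsStore, kubernetesClient)
  val pollingSource = new ExecutorPodsPollingSnapshotSource(
    conf, kubernetesClient, snapshotsStore, pollingExecutor)
  // Order is not critical: both simply publish into the same snapshots store.
  watchSource.start(applicationId)
  pollingSource.start(applicationId)
}
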
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
index a942db6ae02db..c6e931a38405f 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
@@ -17,23 +17,27 @@
package org.apache.spark.scheduler.cluster.k8s
import java.io.File
+import java.util.concurrent.TimeUnit
+import com.google.common.cache.CacheBuilder
import io.fabric8.kubernetes.client.Config
import org.apache.spark.{SparkContext, SparkException}
-import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory}
+import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory}
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.{SystemClock, ThreadUtils}
private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging {
override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s")
override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = {
- if (masterURL.startsWith("k8s") && sc.deployMode == "client") {
+ if (masterURL.startsWith("k8s") &&
+ sc.deployMode == "client" &&
+ !sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK).getOrElse(false)) {
throw new SparkException("Client mode is currently not supported for Kubernetes.")
}
@@ -44,86 +48,55 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit
sc: SparkContext,
masterURL: String,
scheduler: TaskScheduler): SchedulerBackend = {
- val sparkConf = sc.getConf
- val initContainerConfigMap = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_NAME)
- val initContainerConfigMapKey = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_KEY_CONF)
-
- if (initContainerConfigMap.isEmpty) {
- logWarning("The executor's init-container config map is not specified. Executors will " +
- "therefore not attempt to fetch remote or submitted dependencies.")
- }
-
- if (initContainerConfigMapKey.isEmpty) {
- logWarning("The executor's init-container config map key is not specified. Executors will " +
- "therefore not attempt to fetch remote or submitted dependencies.")
- }
-
- // Only set up the bootstrap if they've provided both the config map key and the config map
- // name. The config map might not be provided if init-containers aren't being used to
- // bootstrap dependencies.
- val initContainerBootstrap = for {
- configMap <- initContainerConfigMap
- configMapKey <- initContainerConfigMapKey
- } yield {
- val initContainerImage = sparkConf
- .get(INIT_CONTAINER_IMAGE)
- .getOrElse(throw new SparkException(
- "Must specify the init-container image when there are remote dependencies"))
- new InitContainerBootstrap(
- initContainerImage,
- sparkConf.get(CONTAINER_IMAGE_PULL_POLICY),
- sparkConf.get(JARS_DOWNLOAD_LOCATION),
- sparkConf.get(FILES_DOWNLOAD_LOCATION),
- configMap,
- configMapKey,
- SPARK_POD_EXECUTOR_ROLE,
- sparkConf)
- }
-
val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs(
- sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX)
- val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) {
- Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths))
- } else {
- None
- }
- // Mount user-specified executor secrets also into the executor's init-container. The
- // init-container may need credentials in the secrets to be able to download remote
- // dependencies. The executor's main container and its init-container share the secrets
- // because the init-container is sort of an implementation details and this sharing
- // avoids introducing a dedicated configuration property just for the init-container.
- val initContainerMountSecretsBootstrap = if (initContainerBootstrap.nonEmpty &&
- executorSecretNamesToMountPaths.nonEmpty) {
- Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths))
- } else {
- None
- }
-
+ sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX)
val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient(
KUBERNETES_MASTER_INTERNAL_URL,
- Some(sparkConf.get(KUBERNETES_NAMESPACE)),
+ Some(sc.conf.get(KUBERNETES_NAMESPACE)),
KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX,
- sparkConf,
+ sc.conf,
Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)),
Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH)))
- val executorPodFactory = new ExecutorPodFactory(
- sparkConf,
- mountSecretBootstrap,
- initContainerBootstrap,
- initContainerMountSecretsBootstrap)
-
- val allocatorExecutor = ThreadUtils
- .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator")
val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool(
"kubernetes-executor-requests")
+
+ val subscribersExecutor = ThreadUtils
+ .newDaemonThreadPoolScheduledExecutor(
+ "kubernetes-executor-snapshots-subscribers", 2)
+ val snapshotsStore = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor)
+ val removedExecutorsCache = CacheBuilder.newBuilder()
+ .expireAfterWrite(3, TimeUnit.MINUTES)
+ .build[java.lang.Long, java.lang.Long]()
+ val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager(
+ sc.conf,
+ new KubernetesExecutorBuilder(),
+ kubernetesClient,
+ snapshotsStore,
+ removedExecutorsCache)
+
+ val executorPodsAllocator = new ExecutorPodsAllocator(
+ sc.conf, new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock())
+
+ val podsWatchEventSource = new ExecutorPodsWatchSnapshotSource(
+ snapshotsStore,
+ kubernetesClient)
+
+ val eventsPollingExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+ "kubernetes-executor-pod-polling-sync")
+ val podsPollingEventSource = new ExecutorPodsPollingSnapshotSource(
+ sc.conf, kubernetesClient, snapshotsStore, eventsPollingExecutor)
+
new KubernetesClusterSchedulerBackend(
scheduler.asInstanceOf[TaskSchedulerImpl],
sc.env.rpcEnv,
- executorPodFactory,
kubernetesClient,
- allocatorExecutor,
- requestExecutorsService)
+ requestExecutorsService,
+ snapshotsStore,
+ executorPodsAllocator,
+ executorPodsLifecycleEventHandler,
+ podsWatchEventSource,
+ podsPollingEventSource)
}
override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = {
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
index 9de4b16c30d3c..fa6dc2c479bbf 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
@@ -16,59 +16,32 @@
*/
package org.apache.spark.scheduler.cluster.k8s
-import java.io.Closeable
-import java.net.InetAddress
-import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit}
-import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference}
-import javax.annotation.concurrent.GuardedBy
+import java.util.concurrent.ExecutorService
-import io.fabric8.kubernetes.api.model._
-import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher}
-import io.fabric8.kubernetes.client.Watcher.Action
-import scala.collection.JavaConverters._
-import scala.collection.mutable
+import io.fabric8.kubernetes.client.KubernetesClient
import scala.concurrent.{ExecutionContext, Future}
-import org.apache.spark.SparkException
-import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv}
-import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl}
+import org.apache.spark.rpc.{RpcAddress, RpcEnv}
+import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl}
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
private[spark] class KubernetesClusterSchedulerBackend(
scheduler: TaskSchedulerImpl,
rpcEnv: RpcEnv,
- executorPodFactory: ExecutorPodFactory,
kubernetesClient: KubernetesClient,
- allocatorExecutor: ScheduledExecutorService,
- requestExecutorsService: ExecutorService)
+ requestExecutorsService: ExecutorService,
+ snapshotsStore: ExecutorPodsSnapshotsStore,
+ podAllocator: ExecutorPodsAllocator,
+ lifecycleEventHandler: ExecutorPodsLifecycleManager,
+ watchEvents: ExecutorPodsWatchSnapshotSource,
+ pollEvents: ExecutorPodsPollingSnapshotSource)
extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) {
- import KubernetesClusterSchedulerBackend._
-
- private val EXECUTOR_ID_COUNTER = new AtomicLong(0L)
- private val RUNNING_EXECUTOR_PODS_LOCK = new Object
- @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK")
- private val runningExecutorsToPods = new mutable.HashMap[String, Pod]
- private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]()
- private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]()
- private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]()
-
- private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE)
-
- private val kubernetesDriverPodName = conf
- .get(KUBERNETES_DRIVER_POD_NAME)
- .getOrElse(throw new SparkException("Must specify the driver pod name"))
private implicit val requestExecutorContext = ExecutionContext.fromExecutorService(
requestExecutorsService)
- private val driverPod = kubernetesClient.pods()
- .inNamespace(kubernetesNamespace)
- .withName(kubernetesDriverPodName)
- .get()
-
protected override val minRegisteredRatio =
if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
0.8
@@ -76,367 +49,93 @@ private[spark] class KubernetesClusterSchedulerBackend(
super.minRegisteredRatio
}
- private val executorWatchResource = new AtomicReference[Closeable]
- private val totalExpectedExecutors = new AtomicInteger(0)
-
- private val driverUrl = RpcEndpointAddress(
- conf.get("spark.driver.host"),
- conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT),
- CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
-
private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf)
- private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY)
-
- private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE)
-
- private val executorLostReasonCheckMaxAttempts = conf.get(
- KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS)
-
- private val allocatorRunnable = new Runnable {
-
- // Maintains a map of executor id to count of checks performed to learn the loss reason
- // for an executor.
- private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int]
-
- override def run(): Unit = {
- handleDisconnectedExecutors()
-
- val executorsToAllocate = mutable.Map[String, Pod]()
- val currentTotalRegisteredExecutors = totalRegisteredExecutors.get
- val currentTotalExpectedExecutors = totalExpectedExecutors.get
- val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts()
- RUNNING_EXECUTOR_PODS_LOCK.synchronized {
- if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) {
- logDebug("Waiting for pending executors before scaling")
- } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) {
- logDebug("Maximum allowed executor limit reached. Not scaling up further.")
- } else {
- for (_ <- 0 until math.min(
- currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) {
- val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString
- val executorPod = executorPodFactory.createExecutorPod(
- executorId,
- applicationId(),
- driverUrl,
- conf.getExecutorEnv,
- driverPod,
- currentNodeToLocalTaskCount)
- executorsToAllocate(executorId) = executorPod
- logInfo(
- s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}")
- }
- }
- }
-
- val allocatedExecutors = executorsToAllocate.mapValues { pod =>
- Utils.tryLog {
- kubernetesClient.pods().create(pod)
- }
- }
-
- RUNNING_EXECUTOR_PODS_LOCK.synchronized {
- allocatedExecutors.map {
- case (executorId, attemptedAllocatedExecutor) =>
- attemptedAllocatedExecutor.map { successfullyAllocatedExecutor =>
- runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor)
- }
- }
- }
- }
-
- def handleDisconnectedExecutors(): Unit = {
- // For each disconnected executor, synchronize with the loss reasons that may have been found
- // by the executor pod watcher. If the loss reason was discovered by the watcher,
- // inform the parent class with removeExecutor.
- disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach {
- case (executorId, executorPod) =>
- val knownExitReason = Option(podsWithKnownExitReasons.remove(
- executorPod.getMetadata.getName))
- knownExitReason.fold {
- removeExecutorOrIncrementLossReasonCheckCount(executorId)
- } { executorExited =>
- logWarning(s"Removing executor $executorId with loss reason " + executorExited.message)
- removeExecutor(executorId, executorExited)
- // We don't delete the pod running the executor that has an exit condition caused by
- // the application from the Kubernetes API server. This allows users to debug later on
- // through commands such as "kubectl logs <pod name>" and
- // "kubectl describe pod <pod name>". Note that exited containers have terminated and
- // therefore won't take CPU and memory resources.
- // Otherwise, the executor pod is marked to be deleted from the API server.
- if (executorExited.exitCausedByApp) {
- logInfo(s"Executor $executorId exited because of the application.")
- deleteExecutorFromDataStructures(executorId)
- } else {
- logInfo(s"Executor $executorId failed because of a framework error.")
- deleteExecutorFromClusterAndDataStructures(executorId)
- }
- }
- }
- }
-
- def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = {
- val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0)
- if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) {
- removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons."))
- deleteExecutorFromClusterAndDataStructures(executorId)
- } else {
- executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1)
- }
- }
-
- def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = {
- deleteExecutorFromDataStructures(executorId).foreach { pod =>
- kubernetesClient.pods().delete(pod)
- }
- }
-
- def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = {
- disconnectedPodsByExecutorIdPendingRemoval.remove(executorId)
- executorReasonCheckAttemptCounts -= executorId
- podsWithKnownExitReasons.remove(executorId)
- RUNNING_EXECUTOR_PODS_LOCK.synchronized {
- runningExecutorsToPods.remove(executorId).orElse {
- logWarning(s"Unable to remove pod for unknown executor $executorId")
- None
- }
- }
- }
- }
-
- override def sufficientResourcesRegistered(): Boolean = {
- totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio
+ // Allow removeExecutor to be accessible by ExecutorPodsLifecycleEventHandler
+ private[k8s] def doRemoveExecutor(executorId: String, reason: ExecutorLossReason): Unit = {
+ removeExecutor(executorId, reason)
}
override def start(): Unit = {
super.start()
- executorWatchResource.set(
- kubernetesClient
- .pods()
- .withLabel(SPARK_APP_ID_LABEL, applicationId())
- .watch(new ExecutorPodsWatcher()))
-
- allocatorExecutor.scheduleWithFixedDelay(
- allocatorRunnable, 0L, podAllocationInterval, TimeUnit.MILLISECONDS)
-
if (!Utils.isDynamicAllocationEnabled(conf)) {
- doRequestTotalExecutors(initialExecutors)
+ podAllocator.setTotalExpectedExecutors(initialExecutors)
}
+ lifecycleEventHandler.start(this)
+ podAllocator.start(applicationId())
+ watchEvents.start(applicationId())
+ pollEvents.start(applicationId())
}
override def stop(): Unit = {
- // stop allocation of new resources and caches.
- allocatorExecutor.shutdown()
- allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS)
-
- // send stop message to executors so they shut down cleanly
super.stop()
- try {
- val resource = executorWatchResource.getAndSet(null)
- if (resource != null) {
- resource.close()
- }
- } catch {
- case e: Throwable => logWarning("Failed to close the executor pod watcher", e)
+ Utils.tryLogNonFatalError {
+ snapshotsStore.stop()
}
- // then delete the executor pods
Utils.tryLogNonFatalError {
- deleteExecutorPodsOnStop()
- executorPodsByIPs.clear()
+ watchEvents.stop()
}
+
Utils.tryLogNonFatalError {
- logInfo("Closing kubernetes client")
- kubernetesClient.close()
+ pollEvents.stop()
+ }
+
+ Utils.tryLogNonFatalError {
+ kubernetesClient.pods()
+ .withLabel(SPARK_APP_ID_LABEL, applicationId())
+ .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)
+ .delete()
}
- }
- /**
- * @return A map of K8s cluster nodes to the number of tasks that could benefit from data
- * locality if an executor launches on the cluster node.
- */
- private def getNodesWithLocalTaskCounts() : Map[String, Int] = {
- val nodeToLocalTaskCount = synchronized {
- mutable.Map[String, Int]() ++ hostToLocalTaskCount
+ Utils.tryLogNonFatalError {
+ ThreadUtils.shutdown(requestExecutorsService)
}
- for (pod <- executorPodsByIPs.values().asScala) {
- // Remove cluster nodes that are running our executors already.
- // TODO: This prefers spreading out executors across nodes. In case users want
- // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut
- // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html
- nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty ||
- nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty ||
- nodeToLocalTaskCount.remove(
- InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty
+ Utils.tryLogNonFatalError {
+ kubernetesClient.close()
}
- nodeToLocalTaskCount.toMap[String, Int]
}
override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] {
- totalExpectedExecutors.set(requestedTotal)
+ // TODO when we support dynamic allocation, the pod allocator should be told to process the
+ // current snapshot in order to decrease/increase the number of executors accordingly.
+ podAllocator.setTotalExpectedExecutors(requestedTotal)
true
}
- override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] {
- val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized {
- executorIds.flatMap { executorId =>
- runningExecutorsToPods.remove(executorId) match {
- case Some(pod) =>
- disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod)
- Some(pod)
-
- case None =>
- logWarning(s"Unable to remove pod for unknown executor $executorId")
- None
- }
- }
- }
-
- kubernetesClient.pods().delete(podsToDelete: _*)
- true
+ override def sufficientResourcesRegistered(): Boolean = {
+ totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio
}
- private def deleteExecutorPodsOnStop(): Unit = {
- val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized {
- val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*)
- runningExecutorsToPods.clear()
- runningExecutorPodsCopy
- }
- kubernetesClient.pods().delete(executorPodsToDelete: _*)
+ override def getExecutorIds(): Seq[String] = synchronized {
+ super.getExecutorIds()
}
- private class ExecutorPodsWatcher extends Watcher[Pod] {
-
- private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1
-
- override def eventReceived(action: Action, pod: Pod): Unit = {
- val podName = pod.getMetadata.getName
- val podIP = pod.getStatus.getPodIP
-
- action match {
- case Action.MODIFIED if (pod.getStatus.getPhase == "Running"
- && pod.getMetadata.getDeletionTimestamp == null) =>
- val clusterNodeName = pod.getSpec.getNodeName
- logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.")
- executorPodsByIPs.put(podIP, pod)
-
- case Action.DELETED | Action.ERROR =>
- val executorId = getExecutorId(pod)
- logDebug(s"Executor pod $podName at IP $podIP was at $action.")
- if (podIP != null) {
- executorPodsByIPs.remove(podIP)
- }
-
- val executorExitReason = if (action == Action.ERROR) {
- logWarning(s"Received error event of executor pod $podName. Reason: " +
- pod.getStatus.getReason)
- executorExitReasonOnError(pod)
- } else if (action == Action.DELETED) {
- logWarning(s"Received delete event of executor pod $podName. Reason: " +
- pod.getStatus.getReason)
- executorExitReasonOnDelete(pod)
- } else {
- throw new IllegalStateException(
- s"Unknown action that should only be DELETED or ERROR: $action")
- }
- podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason)
-
- if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) {
- log.warn(s"Executor with id $executorId was not marked as disconnected, but the " +
- s"watch received an event of type $action for this executor. The executor may " +
- "have failed to start in the first place and never registered with the driver.")
- }
- disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod)
-
- case _ => logDebug(s"Received event of executor pod $podName: " + action)
- }
- }
-
- override def onClose(cause: KubernetesClientException): Unit = {
- logDebug("Executor pod watch closed.", cause)
- }
-
- private def getExecutorExitStatus(pod: Pod): Int = {
- val containerStatuses = pod.getStatus.getContainerStatuses
- if (!containerStatuses.isEmpty) {
- // we assume the first container represents the pod status. This assumption may not hold
- // true in the future. Revisit this if side-car containers start running inside executor
- // pods.
- getExecutorExitStatus(containerStatuses.get(0))
- } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS
- }
-
- private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = {
- Option(containerStatus.getState).map { containerState =>
- Option(containerState.getTerminated).map { containerStateTerminated =>
- containerStateTerminated.getExitCode.intValue()
- }.getOrElse(UNKNOWN_EXIT_CODE)
- }.getOrElse(UNKNOWN_EXIT_CODE)
- }
-
- private def isPodAlreadyReleased(pod: Pod): Boolean = {
- val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL)
- RUNNING_EXECUTOR_PODS_LOCK.synchronized {
- !runningExecutorsToPods.contains(executorId)
- }
- }
-
- private def executorExitReasonOnError(pod: Pod): ExecutorExited = {
- val containerExitStatus = getExecutorExitStatus(pod)
- // container was probably actively killed by the driver.
- if (isPodAlreadyReleased(pod)) {
- ExecutorExited(containerExitStatus, exitCausedByApp = false,
- s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " +
- "request.")
- } else {
- val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " +
- s"exited with exit status code $containerExitStatus."
- ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason)
- }
- }
-
- private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = {
- val exitMessage = if (isPodAlreadyReleased(pod)) {
- s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request."
- } else {
- s"Pod ${pod.getMetadata.getName} deleted or lost."
- }
- ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage)
- }
-
- private def getExecutorId(pod: Pod): String = {
- val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL)
- require(executorId != null, "Unexpected pod metadata; expected all executor pods " +
- s"to have label $SPARK_EXECUTOR_ID_LABEL.")
- executorId
- }
+ override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] {
+ kubernetesClient.pods()
+ .withLabel(SPARK_APP_ID_LABEL, applicationId())
+ .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)
+ .withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*)
+ .delete()
+ // Don't do anything else - let event handling from the Kubernetes API do the Spark changes
}
override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
new KubernetesDriverEndpoint(rpcEnv, properties)
}
- private class KubernetesDriverEndpoint(
- rpcEnv: RpcEnv,
- sparkProperties: Seq[(String, String)])
+ private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
extends DriverEndpoint(rpcEnv, sparkProperties) {
override def onDisconnected(rpcAddress: RpcAddress): Unit = {
- addressToExecutorId.get(rpcAddress).foreach { executorId =>
- if (disableExecutor(executorId)) {
- RUNNING_EXECUTOR_PODS_LOCK.synchronized {
- runningExecutorsToPods.get(executorId).foreach { pod =>
- disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod)
- }
- }
- }
- }
+ // Don't do anything besides disabling the executor - allow the Kubernetes API events to
+ // drive the rest of the lifecycle decisions
+ // TODO what if we disconnect from a networking issue? Probably want to mark the executor
+ // to be deleted eventually.
+ addressToExecutorId.get(rpcAddress).foreach(disableExecutor)
}
}
-}
-private object KubernetesClusterSchedulerBackend {
- private val UNKNOWN_EXIT_CODE = -1
}
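
Reviewer note: the rewritten backend replaces the in-class pod watcher and the mutable executor-to-pod maps with a snapshot-driven design — watch and polling sources feed pod updates into a snapshots store, and subscribers such as the allocator and the lifecycle manager react to periodic snapshots and call back into the backend (e.g. doRemoveExecutor). Below is a minimal, self-contained sketch of that subscriber pattern; the SnapshotStore type and its methods are illustrative stand-ins, not the real ExecutorPodsSnapshotsStore API.

import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}
import scala.collection.JavaConverters._

object SnapshotStoreSketch {
  // A "snapshot" is just the current view of executor id -> pod phase.
  type Snapshot = Map[Long, String]

  final class SnapshotStore {
    private val state = new ConcurrentHashMap[Long, String]()
    private val scheduler = Executors.newSingleThreadScheduledExecutor()

    // Event sources (a watch connection, a polling loop) call this on every pod change.
    def updatePod(execId: Long, phase: String): Unit = state.put(execId, phase)

    // Subscribers receive the full snapshot on a fixed cadence, decoupled from event arrival.
    def addSubscriber(intervalMillis: Long)(onSnapshot: Snapshot => Unit): Unit = {
      scheduler.scheduleWithFixedDelay(
        () => onSnapshot(state.asScala.toMap),
        0L, intervalMillis, TimeUnit.MILLISECONDS)
    }

    def stop(): Unit = scheduler.shutdown()
  }

  def main(args: Array[String]): Unit = {
    val store = new SnapshotStore
    // A lifecycle-manager-like subscriber: reacts to terminal phases.
    store.addSubscriber(100) { snapshot =>
      snapshot.collect { case (id, "Failed") => id }
        .foreach(id => println(s"would call doRemoveExecutor($id, ...)"))
    }
    store.updatePod(1L, "Running")
    store.updatePod(2L, "Failed")
    Thread.sleep(300)
    store.stop()
  }
}

The cadence-based delivery means a burst of watch events collapses into a single snapshot, so subscribers reason about current state rather than event ordering.
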
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
new file mode 100644
index 0000000000000..769a0a5a63047
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep}
+
+private[spark] class KubernetesExecutorBuilder(
+ provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep =
+ new BasicExecutorFeatureStep(_),
+ provideSecretsStep:
+ (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep =
+ new MountSecretsFeatureStep(_),
+ provideEnvSecretsStep:
+ (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) =
+ new EnvSecretsFeatureStep(_),
+ provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf])
+ => LocalDirsFeatureStep =
+ new LocalDirsFeatureStep(_)) {
+
+ def buildFromFeatures(
+ kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = {
+ val baseFeatures = Seq(
+ provideBasicStep(kubernetesConf),
+ provideLocalDirsStep(kubernetesConf))
+
+ val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) {
+ Some(provideSecretsStep(kubernetesConf)) } else None
+
+ val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) {
+ Some(provideEnvSecretsStep(kubernetesConf)) } else None
+
+ val allFeatures: Seq[KubernetesFeatureConfigStep] =
+ baseFeatures ++
+ maybeRoleSecretNamesStep.toSeq ++
+ maybeProvideSecretsStep.toSeq
+
+ var executorPod = SparkPod.initialPod()
+ for (feature <- allFeatures) {
+ executorPod = feature.configurePod(executorPod)
+ }
+ executorPod
+ }
+}
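
Reviewer note: KubernetesExecutorBuilder folds a sequence of independent feature steps over an initial pod, and takes the step constructors as function-valued parameters so tests can substitute fakes. A minimal sketch of the same composition pattern, using illustrative stand-in types rather than the real SparkPod/KubernetesConf classes:

object FeatureStepSketch {
  // Stand-ins for SparkPod and the feature-step contract; illustrative only.
  final case class Pod(labels: Map[String, String], volumes: Seq[String])

  trait FeatureStep {
    def configurePod(pod: Pod): Pod
  }

  // Each step owns one concern and only touches the parts of the pod it cares about.
  final class LabelStep(labels: Map[String, String]) extends FeatureStep {
    override def configurePod(pod: Pod): Pod = pod.copy(labels = pod.labels ++ labels)
  }
  final class SecretVolumeStep(secretNames: Seq[String]) extends FeatureStep {
    override def configurePod(pod: Pod): Pod =
      pod.copy(volumes = pod.volumes ++ secretNames.map(n => s"secret-volume-$n"))
  }

  // The builder is just a left fold of configurePod over the chosen steps.
  def buildFromFeatures(steps: Seq[FeatureStep]): Pod =
    steps.foldLeft(Pod(Map.empty, Seq.empty))((pod, step) => step.configurePod(pod))

  def main(args: Array[String]): Unit = {
    val optionalSecrets = Seq("db-password")
    val steps = Seq(new LabelStep(Map("spark-role" -> "executor"))) ++
      (if (optionalSecrets.nonEmpty) Seq(new SecretVolumeStep(optionalSecrets)) else Nil)
    println(buildFromFeatures(steps))
  }
}
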
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala
new file mode 100644
index 0000000000000..527fc6b0d8f87
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s
+
+import io.fabric8.kubernetes.api.model.{DoneablePod, HasMetadata, Pod, PodList}
+import io.fabric8.kubernetes.client.{Watch, Watcher}
+import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource}
+
+object Fabric8Aliases {
+ type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]
+ type LABELED_PODS = FilterWatchListDeletable[
+ Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]]
+ type SINGLE_POD = PodResource[Pod, DoneablePod]
+ type RESOURCE_LIST = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[
+ HasMetadata, Boolean]
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala
new file mode 100644
index 0000000000000..661f942435921
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.k8s
+
+import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder}
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.submit._
+
+class KubernetesConfSuite extends SparkFunSuite {
+
+ private val APP_NAME = "test-app"
+ private val RESOURCE_NAME_PREFIX = "prefix"
+ private val APP_ID = "test-id"
+ private val MAIN_CLASS = "test-class"
+ private val APP_ARGS = Array("arg1", "arg2")
+ private val CUSTOM_LABELS = Map(
+ "customLabel1Key" -> "customLabel1Value",
+ "customLabel2Key" -> "customLabel2Value")
+ private val CUSTOM_ANNOTATIONS = Map(
+ "customAnnotation1Key" -> "customAnnotation1Value",
+ "customAnnotation2Key" -> "customAnnotation2Value")
+ private val SECRET_NAMES_TO_MOUNT_PATHS = Map(
+ "secret1" -> "/mnt/secrets/secret1",
+ "secret2" -> "/mnt/secrets/secret2")
+ private val SECRET_ENV_VARS = Map(
+ "envName1" -> "name1:key1",
+ "envName2" -> "name2:key2")
+ private val CUSTOM_ENVS = Map(
+ "customEnvKey1" -> "customEnvValue1",
+ "customEnvKey2" -> "customEnvValue2")
+ private val DRIVER_POD = new PodBuilder().build()
+ private val EXECUTOR_ID = "executor-id"
+
+ test("Basic driver translated fields.") {
+ val sparkConf = new SparkConf(false)
+ val conf = KubernetesConf.createDriverConf(
+ sparkConf,
+ APP_NAME,
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ mainAppResource = None,
+ MAIN_CLASS,
+ APP_ARGS,
+ maybePyFiles = None)
+ assert(conf.appId === APP_ID)
+ assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap)
+ assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX)
+ assert(conf.roleSpecificConf.appName === APP_NAME)
+ assert(conf.roleSpecificConf.mainAppResource.isEmpty)
+ assert(conf.roleSpecificConf.mainClass === MAIN_CLASS)
+ assert(conf.roleSpecificConf.appArgs === APP_ARGS)
+ }
+
+ test("Creating driver conf with and without the main app jar influences spark.jars") {
+ val sparkConf = new SparkConf(false)
+ .setJars(Seq("local:///opt/spark/jar1.jar"))
+ val mainAppJar = Some(JavaMainAppResource("local:///opt/spark/main.jar"))
+ val kubernetesConfWithMainJar = KubernetesConf.createDriverConf(
+ sparkConf,
+ APP_NAME,
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ mainAppJar,
+ MAIN_CLASS,
+ APP_ARGS,
+ maybePyFiles = None)
+ assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars")
+ .split(",")
+ === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar"))
+ val kubernetesConfWithoutMainJar = KubernetesConf.createDriverConf(
+ sparkConf,
+ APP_NAME,
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ mainAppResource = None,
+ MAIN_CLASS,
+ APP_ARGS,
+ maybePyFiles = None)
+ assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",")
+ === Array("local:///opt/spark/jar1.jar"))
+ assert(kubernetesConfWithoutMainJar.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.1)
+ }
+
+ test("Creating driver conf with a python primary file") {
+ val mainResourceFile = "local:///opt/spark/main.py"
+ val inputPyFiles = Array("local:///opt/spark/example2.py", "local:///example3.py")
+ val sparkConf = new SparkConf(false)
+ .setJars(Seq("local:///opt/spark/jar1.jar"))
+ .set("spark.files", "local:///opt/spark/example4.py")
+ val mainAppResource = Some(PythonMainAppResource(mainResourceFile))
+ val kubernetesConfWithMainResource = KubernetesConf.createDriverConf(
+ sparkConf,
+ APP_NAME,
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ mainAppResource,
+ MAIN_CLASS,
+ APP_ARGS,
+ Some(inputPyFiles.mkString(",")))
+ assert(kubernetesConfWithMainResource.sparkConf.get("spark.jars").split(",")
+ === Array("local:///opt/spark/jar1.jar"))
+ assert(kubernetesConfWithMainResource.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.4)
+ assert(kubernetesConfWithMainResource.sparkFiles
+ === Array("local:///opt/spark/example4.py", mainResourceFile) ++ inputPyFiles)
+ }
+
+ test("Testing explicit setting of memory overhead on non-JVM tasks") {
+ val sparkConf = new SparkConf(false)
+ .set(MEMORY_OVERHEAD_FACTOR, 0.3)
+
+ val mainResourceFile = "local:///opt/spark/main.py"
+ val mainAppResource = Some(PythonMainAppResource(mainResourceFile))
+ val conf = KubernetesConf.createDriverConf(
+ sparkConf,
+ APP_NAME,
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ mainAppResource,
+ MAIN_CLASS,
+ APP_ARGS,
+ None)
+ assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3)
+ }
+
+ test("Resolve driver labels, annotations, secret mount paths, envs, and memory overhead") {
+ val sparkConf = new SparkConf(false)
+ .set(MEMORY_OVERHEAD_FACTOR, 0.3)
+ CUSTOM_LABELS.foreach { case (key, value) =>
+ sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value)
+ }
+ CUSTOM_ANNOTATIONS.foreach { case (key, value) =>
+ sparkConf.set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$key", value)
+ }
+ SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) =>
+ sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value)
+ }
+ SECRET_ENV_VARS.foreach { case (key, value) =>
+ sparkConf.set(s"$KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX$key", value)
+ }
+ CUSTOM_ENVS.foreach { case (key, value) =>
+ sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value)
+ }
+
+ val conf = KubernetesConf.createDriverConf(
+ sparkConf,
+ APP_NAME,
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ mainAppResource = None,
+ MAIN_CLASS,
+ APP_ARGS,
+ maybePyFiles = None)
+ assert(conf.roleLabels === Map(
+ SPARK_APP_ID_LABEL -> APP_ID,
+ SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++
+ CUSTOM_LABELS)
+ assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS)
+ assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS)
+ assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS)
+ assert(conf.roleEnvs === CUSTOM_ENVS)
+ assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3)
+ }
+
+ test("Basic executor translated fields.") {
+ val conf = KubernetesConf.createExecutorConf(
+ new SparkConf(false),
+ EXECUTOR_ID,
+ APP_ID,
+ DRIVER_POD)
+ assert(conf.roleSpecificConf.executorId === EXECUTOR_ID)
+ assert(conf.roleSpecificConf.driverPod === DRIVER_POD)
+ }
+
+ test("Image pull secrets.") {
+ val conf = KubernetesConf.createExecutorConf(
+ new SparkConf(false)
+ .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "),
+ EXECUTOR_ID,
+ APP_ID,
+ DRIVER_POD)
+ assert(conf.imagePullSecrets() ===
+ Seq(
+ new LocalObjectReferenceBuilder().withName("my-secret-1").build(),
+ new LocalObjectReferenceBuilder().withName("my-secret-2").build()))
+ }
+
+ test("Set executor labels, annotations, and secrets") {
+ val sparkConf = new SparkConf(false)
+ CUSTOM_LABELS.foreach { case (key, value) =>
+ sparkConf.set(s"$KUBERNETES_EXECUTOR_LABEL_PREFIX$key", value)
+ }
+ CUSTOM_ANNOTATIONS.foreach { case (key, value) =>
+ sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value)
+ }
+ SECRET_ENV_VARS.foreach { case (key, value) =>
+ sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX$key", value)
+ }
+ SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) =>
+ sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value)
+ }
+
+ val conf = KubernetesConf.createExecutorConf(
+ sparkConf,
+ EXECUTOR_ID,
+ APP_ID,
+ DRIVER_POD)
+ assert(conf.roleLabels === Map(
+ SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID,
+ SPARK_APP_ID_LABEL -> APP_ID,
+ SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS)
+ assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS)
+ assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS)
+ assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS)
+ }
+}
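
Reviewer note: much of what this suite checks is the prefix convention, where spark.kubernetes.{driver,executor}.label.*, .annotation.*, .secrets.* and .secretKeyRef.* entries are collected into maps on KubernetesConf. A small illustrative sketch of that parsing idea — not the real KubernetesUtils.parsePrefixedKeyValuePairs implementation, and assuming SparkConf.getAllWithPrefix is available:

import org.apache.spark.SparkConf

object PrefixedConfSketch {
  // Every "spark.kubernetes.executor.label.<key> = <value>" entry becomes one
  // (<key> -> <value>) pair with the prefix stripped.
  def parsePrefixedKeyValuePairs(conf: SparkConf, prefix: String): Map[String, String] =
    conf.getAllWithPrefix(prefix).toMap

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      .set("spark.kubernetes.executor.label.team", "data-eng")
      .set("spark.kubernetes.executor.label.env", "staging")
    // Expected: Map(team -> data-eng, env -> staging)
    println(parsePrefixedKeyValuePairs(conf, "spark.kubernetes.executor.label."))
  }
}
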
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala
deleted file mode 100644
index e0f29ecd0fb53..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s
-
-import java.io.File
-import java.util.UUID
-
-import com.google.common.base.Charsets
-import com.google.common.io.Files
-import org.mockito.Mockito
-import org.scalatest.BeforeAndAfter
-import org.scalatest.mockito.MockitoSugar._
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.util.Utils
-
-class SparkPodInitContainerSuite extends SparkFunSuite with BeforeAndAfter {
-
- private val DOWNLOAD_JARS_SECRET_LOCATION = createTempFile("txt")
- private val DOWNLOAD_FILES_SECRET_LOCATION = createTempFile("txt")
-
- private var downloadJarsDir: File = _
- private var downloadFilesDir: File = _
- private var downloadJarsSecretValue: String = _
- private var downloadFilesSecretValue: String = _
- private var fileFetcher: FileFetcher = _
-
- override def beforeAll(): Unit = {
- downloadJarsSecretValue = Files.toString(
- new File(DOWNLOAD_JARS_SECRET_LOCATION), Charsets.UTF_8)
- downloadFilesSecretValue = Files.toString(
- new File(DOWNLOAD_FILES_SECRET_LOCATION), Charsets.UTF_8)
- }
-
- before {
- downloadJarsDir = Utils.createTempDir()
- downloadFilesDir = Utils.createTempDir()
- fileFetcher = mock[FileFetcher]
- }
-
- after {
- downloadJarsDir.delete()
- downloadFilesDir.delete()
- }
-
- test("Downloads from remote server should invoke the file fetcher") {
- val sparkConf = getSparkConfForRemoteFileDownloads
- val initContainerUnderTest = new SparkPodInitContainer(sparkConf, fileFetcher)
- initContainerUnderTest.run()
- Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/jar1.jar", downloadJarsDir)
- Mockito.verify(fileFetcher).fetchFile("hdfs://localhost:9000/jar2.jar", downloadJarsDir)
- Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/file.txt", downloadFilesDir)
- }
-
- private def getSparkConfForRemoteFileDownloads: SparkConf = {
- new SparkConf(true)
- .set(INIT_CONTAINER_REMOTE_JARS,
- "http://localhost:9000/jar1.jar,hdfs://localhost:9000/jar2.jar")
- .set(INIT_CONTAINER_REMOTE_FILES,
- "http://localhost:9000/file.txt")
- .set(JARS_DOWNLOAD_LOCATION, downloadJarsDir.getAbsolutePath)
- .set(FILES_DOWNLOAD_LOCATION, downloadFilesDir.getAbsolutePath)
- }
-
- private def createTempFile(extension: String): String = {
- val dir = Utils.createTempDir()
- val file = new File(dir, s"${UUID.randomUUID().toString}.$extension")
- Files.write(UUID.randomUUID().toString, file, Charsets.UTF_8)
- file.getAbsolutePath
- }
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
new file mode 100644
index 0000000000000..04b909db9d9f3
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import scala.collection.JavaConverters._
+
+import io.fabric8.kubernetes.api.model.LocalObjectReferenceBuilder
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.submit.JavaMainAppResource
+import org.apache.spark.deploy.k8s.submit.PythonMainAppResource
+
+class BasicDriverFeatureStepSuite extends SparkFunSuite {
+
+ private val APP_ID = "spark-app-id"
+ private val RESOURCE_NAME_PREFIX = "spark"
+ private val DRIVER_LABELS = Map("labelkey" -> "labelvalue")
+ private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent"
+ private val APP_NAME = "spark-test"
+ private val MAIN_CLASS = "org.apache.spark.examples.SparkPi"
+ private val PY_MAIN_CLASS = "org.apache.spark.deploy.PythonRunner"
+ private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"")
+ private val CUSTOM_ANNOTATION_KEY = "customAnnotation"
+ private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue"
+ private val DRIVER_ANNOTATIONS = Map(CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE)
+ private val DRIVER_CUSTOM_ENV1 = "customDriverEnv1"
+ private val DRIVER_CUSTOM_ENV2 = "customDriverEnv2"
+ private val DRIVER_ENVS = Map(
+ DRIVER_CUSTOM_ENV1 -> DRIVER_CUSTOM_ENV1,
+ DRIVER_CUSTOM_ENV2 -> DRIVER_CUSTOM_ENV2)
+ private val TEST_IMAGE_PULL_SECRETS = Seq("my-secret-1", "my-secret-2")
+ private val TEST_IMAGE_PULL_SECRET_OBJECTS =
+ TEST_IMAGE_PULL_SECRETS.map { secret =>
+ new LocalObjectReferenceBuilder().withName(secret).build()
+ }
+
+ test("Check the pod respects all configurations from the user.") {
+ val sparkConf = new SparkConf()
+ .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod")
+ .set("spark.driver.cores", "2")
+ .set(KUBERNETES_DRIVER_LIMIT_CORES, "4")
+ .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M")
+ .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L)
+ .set(CONTAINER_IMAGE, "spark-driver:latest")
+ .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(","))
+ val kubernetesConf = KubernetesConf(
+ sparkConf,
+ KubernetesDriverSpecificConf(
+ Some(JavaMainAppResource("")),
+ APP_NAME,
+ MAIN_CLASS,
+ APP_ARGS),
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ DRIVER_LABELS,
+ DRIVER_ANNOTATIONS,
+ Map.empty,
+ Map.empty,
+ DRIVER_ENVS,
+ Seq.empty[String])
+
+ val featureStep = new BasicDriverFeatureStep(kubernetesConf)
+ val basePod = SparkPod.initialPod()
+ val configuredPod = featureStep.configurePod(basePod)
+
+ assert(configuredPod.container.getName === DRIVER_CONTAINER_NAME)
+ assert(configuredPod.container.getImage === "spark-driver:latest")
+ assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY)
+
+ assert(configuredPod.container.getEnv.size === 3)
+ val envs = configuredPod.container
+ .getEnv
+ .asScala
+ .map(env => (env.getName, env.getValue))
+ .toMap
+ assert(envs(DRIVER_CUSTOM_ENV1) === DRIVER_ENVS(DRIVER_CUSTOM_ENV1))
+ assert(envs(DRIVER_CUSTOM_ENV2) === DRIVER_ENVS(DRIVER_CUSTOM_ENV2))
+
+ assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala ===
+ TEST_IMAGE_PULL_SECRET_OBJECTS)
+
+ assert(configuredPod.container.getEnv.asScala.exists(envVar =>
+ envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) &&
+ envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") &&
+ envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP")))
+
+ val resourceRequirements = configuredPod.container.getResources
+ val requests = resourceRequirements.getRequests.asScala
+ assert(requests("cpu").getAmount === "2")
+ assert(requests("memory").getAmount === "456Mi")
+ val limits = resourceRequirements.getLimits.asScala
+ assert(limits("memory").getAmount === "456Mi")
+ assert(limits("cpu").getAmount === "4")
+
+ val driverPodMetadata = configuredPod.pod.getMetadata
+ assert(driverPodMetadata.getName === "spark-driver-pod")
+ assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS)
+ assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS)
+ assert(configuredPod.pod.getSpec.getRestartPolicy === "Never")
+ val expectedSparkConf = Map(
+ KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod",
+ "spark.app.id" -> APP_ID,
+ KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX,
+ "spark.kubernetes.submitInDriver" -> "true")
+ assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf)
+ }
+
+ test("Check appropriate entrypoint rerouting for various bindings") {
+ val javaSparkConf = new SparkConf()
+ .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g")
+ .set(CONTAINER_IMAGE, "spark-driver:latest")
+ val pythonSparkConf = new SparkConf()
+ .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g")
+ .set(CONTAINER_IMAGE, "spark-driver:latest")
+ val javaKubernetesConf = KubernetesConf(
+ javaSparkConf,
+ KubernetesDriverSpecificConf(
+ Some(JavaMainAppResource("")),
+ APP_NAME,
+ PY_MAIN_CLASS,
+ APP_ARGS),
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ DRIVER_LABELS,
+ DRIVER_ANNOTATIONS,
+ Map.empty,
+ Map.empty,
+ DRIVER_ENVS,
+ Seq.empty[String])
+ val pythonKubernetesConf = KubernetesConf(
+ pythonSparkConf,
+ KubernetesDriverSpecificConf(
+ Some(PythonMainAppResource("")),
+ APP_NAME,
+ PY_MAIN_CLASS,
+ APP_ARGS),
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ DRIVER_LABELS,
+ DRIVER_ANNOTATIONS,
+ Map.empty,
+ Map.empty,
+ DRIVER_ENVS,
+ Seq.empty[String])
+ val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf)
+ val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf)
+ val basePod = SparkPod.initialPod()
+ val configuredJavaPod = javaFeatureStep.configurePod(basePod)
+ val configuredPythonPod = pythonFeatureStep.configurePod(basePod)
+ }
+
+ test("Additional system properties resolve jars and set cluster-mode confs.") {
+ val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar")
+ val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt")
+ val sparkConf = new SparkConf()
+ .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod")
+ .setJars(allJars)
+ .set("spark.files", allFiles.mkString(","))
+ .set(CONTAINER_IMAGE, "spark-driver:latest")
+ val kubernetesConf = KubernetesConf(
+ sparkConf,
+ KubernetesDriverSpecificConf(
+ Some(JavaMainAppResource("")),
+ APP_NAME,
+ MAIN_CLASS,
+ APP_ARGS),
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ DRIVER_LABELS,
+ DRIVER_ANNOTATIONS,
+ Map.empty,
+ Map.empty,
+ DRIVER_ENVS,
+ allFiles)
+ val step = new BasicDriverFeatureStep(kubernetesConf)
+ val additionalProperties = step.getAdditionalPodSystemProperties()
+ val expectedSparkConf = Map(
+ KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod",
+ "spark.app.id" -> APP_ID,
+ KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX,
+ "spark.kubernetes.submitInDriver" -> "true",
+ "spark.jars" -> "/opt/spark/jar1.jar,hdfs:///opt/spark/jar2.jar",
+ "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt")
+ assert(additionalProperties === expectedSparkConf)
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
new file mode 100644
index 0000000000000..f06030aa55c0c
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import scala.collection.JavaConverters._
+
+import io.fabric8.kubernetes.api.model._
+import org.mockito.MockitoAnnotations
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach}
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.rpc.RpcEndpointAddress
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+
+class BasicExecutorFeatureStepSuite
+ extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach {
+
+ private val APP_ID = "app-id"
+ private val DRIVER_HOSTNAME = "localhost"
+ private val DRIVER_PORT = 7098
+ private val DRIVER_ADDRESS = RpcEndpointAddress(
+ DRIVER_HOSTNAME,
+ DRIVER_PORT.toInt,
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
+ private val DRIVER_POD_NAME = "driver-pod"
+
+ private val DRIVER_POD_UID = "driver-uid"
+ private val RESOURCE_NAME_PREFIX = "base"
+ private val EXECUTOR_IMAGE = "executor-image"
+ private val LABELS = Map("label1key" -> "label1value")
+ private val ANNOTATIONS = Map("annotation1key" -> "annotation1value")
+ private val TEST_IMAGE_PULL_SECRETS = Seq("my-1secret-1", "my-secret-2")
+ private val TEST_IMAGE_PULL_SECRET_OBJECTS =
+ TEST_IMAGE_PULL_SECRETS.map { secret =>
+ new LocalObjectReferenceBuilder().withName(secret).build()
+ }
+ private val DRIVER_POD = new PodBuilder()
+ .withNewMetadata()
+ .withName(DRIVER_POD_NAME)
+ .withUid(DRIVER_POD_UID)
+ .endMetadata()
+ .withNewSpec()
+ .withNodeName("some-node")
+ .endSpec()
+ .withNewStatus()
+ .withHostIP("192.168.99.100")
+ .endStatus()
+ .build()
+ private var baseConf: SparkConf = _
+
+ before {
+ MockitoAnnotations.initMocks(this)
+ baseConf = new SparkConf()
+ .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME)
+ .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX)
+ .set(CONTAINER_IMAGE, EXECUTOR_IMAGE)
+ .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true)
+ .set("spark.driver.host", DRIVER_HOSTNAME)
+ .set("spark.driver.port", DRIVER_PORT.toString)
+ .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(","))
+ }
+
+ test("basic executor pod has reasonable defaults") {
+ val step = new BasicExecutorFeatureStep(
+ KubernetesConf(
+ baseConf,
+ KubernetesExecutorSpecificConf("1", DRIVER_POD),
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ LABELS,
+ ANNOTATIONS,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String]))
+ val executor = step.configurePod(SparkPod.initialPod())
+
+ // The executor pod name and default labels.
+ assert(executor.pod.getMetadata.getName === s"$RESOURCE_NAME_PREFIX-exec-1")
+ assert(executor.pod.getMetadata.getLabels.asScala === LABELS)
+ assert(executor.pod.getSpec.getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS)
+
+ // There is exactly 1 container with no volume mounts and default memory limits.
+ // Default memory limit is 1024M + 384M (minimum overhead constant).
+ assert(executor.container.getImage === EXECUTOR_IMAGE)
+ assert(executor.container.getVolumeMounts.isEmpty)
+ assert(executor.container.getResources.getLimits.size() === 1)
+ assert(executor.container.getResources
+ .getLimits.get("memory").getAmount === "1408Mi")
+
+ // The pod has no node selector, volumes.
+ assert(executor.pod.getSpec.getNodeSelector.isEmpty)
+ assert(executor.pod.getSpec.getVolumes.isEmpty)
+
+ checkEnv(executor, Map())
+ checkOwnerReferences(executor.pod, DRIVER_POD_UID)
+ }
+
+ test("executor pod hostnames get truncated to 63 characters") {
+ val conf = baseConf.clone()
+ val longPodNamePrefix = "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple"
+
+ val step = new BasicExecutorFeatureStep(
+ KubernetesConf(
+ conf,
+ KubernetesExecutorSpecificConf("1", DRIVER_POD),
+ longPodNamePrefix,
+ APP_ID,
+ LABELS,
+ ANNOTATIONS,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String]))
+ assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63)
+ }
+
+ test("classpath and extra java options get translated into environment variables") {
+ val conf = baseConf.clone()
+ conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar")
+ conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz")
+
+ val step = new BasicExecutorFeatureStep(
+ KubernetesConf(
+ conf,
+ KubernetesExecutorSpecificConf("1", DRIVER_POD),
+ RESOURCE_NAME_PREFIX,
+ APP_ID,
+ LABELS,
+ ANNOTATIONS,
+ Map.empty,
+ Map.empty,
+ Map("qux" -> "quux"),
+ Seq.empty[String]))
+ val executor = step.configurePod(SparkPod.initialPod())
+
+ checkEnv(executor,
+ Map("SPARK_JAVA_OPT_0" -> "foo=bar",
+ ENV_CLASSPATH -> "bar=baz",
+ "qux" -> "quux"))
+ checkOwnerReferences(executor.pod, DRIVER_POD_UID)
+ }
+
+ // There is always exactly one controller reference, and it points to the driver pod.
+ private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = {
+ assert(executor.getMetadata.getOwnerReferences.size() === 1)
+ assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid)
+ assert(executor.getMetadata.getOwnerReferences.get(0).getController === true)
+ }
+
+ // Check that the expected environment variables are present.
+ private def checkEnv(executorPod: SparkPod, additionalEnvVars: Map[String, String]): Unit = {
+ val defaultEnvs = Map(
+ ENV_EXECUTOR_ID -> "1",
+ ENV_DRIVER_URL -> DRIVER_ADDRESS.toString,
+ ENV_EXECUTOR_CORES -> "1",
+ ENV_EXECUTOR_MEMORY -> "1g",
+ ENV_APPLICATION_ID -> APP_ID,
+ ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL,
+ ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars
+
+ assert(executorPod.container.getEnv.size() === defaultEnvs.size)
+ val mapEnvs = executorPod.container.getEnv.asScala.map {
+ x => (x.getName, x.getValue)
+ }.toMap
+ assert(defaultEnvs === mapEnvs)
+ }
+}
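
Reviewer note: the memory assertions in this suite and in BasicDriverFeatureStepSuite follow from requested memory plus overhead. A worked sketch of the arithmetic, assuming the default 0.1 overhead factor and the 384 MiB minimum overhead these tests rely on:

object MemoryLimitSketch {
  // Assumed formula (matching the assertions above): container memory limit =
  // executor memory + max(overheadFactor * executor memory, 384 MiB minimum overhead).
  def executorLimitMiB(executorMemoryMiB: Long, overheadFactor: Double = 0.1): Long =
    executorMemoryMiB + math.max((overheadFactor * executorMemoryMiB).toLong, 384L)

  def main(args: Array[String]): Unit = {
    println(executorLimitMiB(1024)) // 1408 -> the "1408Mi" asserted above
    println(256 + 200)              // 456  -> the "456Mi" asserted in the driver suite
  }
}
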
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala
similarity index 66%
rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala
rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala
index 64553d25883bb..7cea83591f3e8 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala
@@ -14,34 +14,35 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.deploy.k8s.submit.steps
+package org.apache.spark.deploy.k8s.features
import java.io.File
-import scala.collection.JavaConverters._
-
import com.google.common.base.Charsets
import com.google.common.io.{BaseEncoding, Files}
import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret}
+import org.mockito.{Mock, MockitoAnnotations}
import org.scalatest.BeforeAndAfter
+import scala.collection.JavaConverters._
import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod}
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
import org.apache.spark.util.Utils
-class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndAfter {
+class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
private val KUBERNETES_RESOURCE_NAME_PREFIX = "spark"
+ private val APP_ID = "k8s-app"
private var credentialsTempDirectory: File = _
- private val BASE_DRIVER_SPEC = new KubernetesDriverSpec(
- driverPod = new PodBuilder().build(),
- driverContainer = new ContainerBuilder().build(),
- driverSparkConf = new SparkConf(false),
- otherKubernetesResources = Seq.empty[HasMetadata])
+ private val BASE_DRIVER_POD = SparkPod.initialPod()
+
+ @Mock
+ private var driverSpecificConf: KubernetesDriverSpecificConf = _
before {
+ MockitoAnnotations.initMocks(this)
credentialsTempDirectory = Utils.createTempDir()
}
@@ -50,13 +51,21 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA
}
test("Don't set any credentials") {
- val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep(
- new SparkConf(false), KUBERNETES_RESOURCE_NAME_PREFIX)
- val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC)
- assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod)
- assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer)
- assert(preparedDriverSpec.otherKubernetesResources.isEmpty)
- assert(preparedDriverSpec.driverSparkConf.getAll.isEmpty)
+ val kubernetesConf = KubernetesConf(
+ new SparkConf(false),
+ driverSpecificConf,
+ KUBERNETES_RESOURCE_NAME_PREFIX,
+ APP_ID,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String])
+ val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf)
+ assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD)
+ assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty)
+ assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty)
}
test("Only set credentials that are manually mounted.") {
@@ -73,14 +82,25 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA
.set(
s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX",
"/mnt/secrets/my-ca.pem")
+ val kubernetesConf = KubernetesConf(
+ submissionSparkConf,
+ driverSpecificConf,
+ KUBERNETES_RESOURCE_NAME_PREFIX,
+ APP_ID,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String])
- val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep(
- submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX)
- val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC)
- assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod)
- assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer)
- assert(preparedDriverSpec.otherKubernetesResources.isEmpty)
- assert(preparedDriverSpec.driverSparkConf.getAll.toMap === submissionSparkConf.getAll.toMap)
+ val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf)
+ assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD)
+ assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty)
+ val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties()
+ resolvedProperties.foreach { case (propKey, propValue) =>
+ assert(submissionSparkConf.get(propKey) === propValue)
+ }
}
test("Mount credentials from the submission client as a secret.") {
@@ -100,10 +120,19 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA
.set(
s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX",
caCertFile.getAbsolutePath)
- val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep(
- submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX)
- val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(
- BASE_DRIVER_SPEC.copy(driverSparkConf = submissionSparkConf))
+ val kubernetesConf = KubernetesConf(
+ submissionSparkConf,
+ driverSpecificConf,
+ KUBERNETES_RESOURCE_NAME_PREFIX,
+ APP_ID,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String])
+ val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf)
+ val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties()
val expectedSparkConf = Map(
s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX" -> "",
s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" ->
@@ -113,16 +142,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA
s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" ->
DRIVER_CREDENTIALS_CLIENT_CERT_PATH,
s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" ->
- DRIVER_CREDENTIALS_CA_CERT_PATH,
- s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" ->
- clientKeyFile.getAbsolutePath,
- s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" ->
- clientCertFile.getAbsolutePath,
- s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" ->
- caCertFile.getAbsolutePath)
- assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf)
- assert(preparedDriverSpec.otherKubernetesResources.size === 1)
- val credentialsSecret = preparedDriverSpec.otherKubernetesResources.head.asInstanceOf[Secret]
+ DRIVER_CREDENTIALS_CA_CERT_PATH)
+ assert(resolvedProperties === expectedSparkConf)
+ assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().size === 1)
+ val credentialsSecret = kubernetesCredentialsStep
+ .getAdditionalKubernetesResources()
+ .head
+ .asInstanceOf[Secret]
assert(credentialsSecret.getMetadata.getName ===
s"$KUBERNETES_RESOURCE_NAME_PREFIX-kubernetes-credentials")
val decodedSecretData = credentialsSecret.getData.asScala.map { data =>
@@ -134,12 +160,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA
DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME -> "key",
DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME -> "cert")
assert(decodedSecretData === expectedSecretData)
- val driverPodVolumes = preparedDriverSpec.driverPod.getSpec.getVolumes.asScala
+ val driverPod = kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD)
+ val driverPodVolumes = driverPod.pod.getSpec.getVolumes.asScala
assert(driverPodVolumes.size === 1)
assert(driverPodVolumes.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME)
assert(driverPodVolumes.head.getSecret != null)
assert(driverPodVolumes.head.getSecret.getSecretName === credentialsSecret.getMetadata.getName)
- val driverContainerVolumeMount = preparedDriverSpec.driverContainer.getVolumeMounts.asScala
+ val driverContainerVolumeMount = driverPod.container.getVolumeMounts.asScala
assert(driverContainerVolumeMount.size === 1)
assert(driverContainerVolumeMount.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME)
assert(driverContainerVolumeMount.head.getMountPath === DRIVER_CREDENTIALS_SECRETS_BASE_DIR)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala
new file mode 100644
index 0000000000000..77d38bf19cd10
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import io.fabric8.kubernetes.api.model.Service
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Mockito.when
+import org.scalatest.BeforeAndAfter
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.util.Clock
+
+class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
+
+ private val SHORT_RESOURCE_NAME_PREFIX =
+ "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH -
+ DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length)
+
+ private val LONG_RESOURCE_NAME_PREFIX =
+ "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH -
+ DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length + 1)
+ private val DRIVER_LABELS = Map(
+ "label1key" -> "label1value",
+ "label2key" -> "label2value")
+
+ @Mock
+ private var clock: Clock = _
+
+ private var sparkConf: SparkConf = _
+
+ before {
+ MockitoAnnotations.initMocks(this)
+ sparkConf = new SparkConf(false)
+ }
+
+ test("Headless service has a port for the driver RPC and the block manager.") {
+ sparkConf = sparkConf
+ .set("spark.driver.port", "9000")
+ .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080)
+ val configurationStep = new DriverServiceFeatureStep(
+ KubernetesConf(
+ sparkConf,
+ KubernetesDriverSpecificConf(
+ None, "main", "app", Seq.empty),
+ SHORT_RESOURCE_NAME_PREFIX,
+ "app-id",
+ DRIVER_LABELS,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String]))
+ assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod())
+ assert(configurationStep.getAdditionalKubernetesResources().size === 1)
+ assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service])
+ val driverService = configurationStep
+ .getAdditionalKubernetesResources()
+ .head
+ .asInstanceOf[Service]
+ verifyService(
+ 9000,
+ 8080,
+ s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}",
+ driverService)
+ }
+
+ test("Hostname and ports are set according to the service name.") {
+ val configurationStep = new DriverServiceFeatureStep(
+ KubernetesConf(
+ sparkConf
+ .set("spark.driver.port", "9000")
+ .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080)
+ .set(KUBERNETES_NAMESPACE, "my-namespace"),
+ KubernetesDriverSpecificConf(
+ None, "main", "app", Seq.empty),
+ SHORT_RESOURCE_NAME_PREFIX,
+ "app-id",
+ DRIVER_LABELS,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String]))
+ val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX +
+ DriverServiceFeatureStep.DRIVER_SVC_POSTFIX
+ val expectedHostName = s"$expectedServiceName.my-namespace.svc"
+ val additionalProps = configurationStep.getAdditionalPodSystemProperties()
+ verifySparkConfHostNames(additionalProps, expectedHostName)
+ }
+
+ test("Ports should resolve to defaults in SparkConf and in the service.") {
+ val configurationStep = new DriverServiceFeatureStep(
+ KubernetesConf(
+ sparkConf,
+ KubernetesDriverSpecificConf(
+ None, "main", "app", Seq.empty),
+ SHORT_RESOURCE_NAME_PREFIX,
+ "app-id",
+ DRIVER_LABELS,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String]))
+ val resolvedService = configurationStep
+ .getAdditionalKubernetesResources()
+ .head
+ .asInstanceOf[Service]
+ verifyService(
+ DEFAULT_DRIVER_PORT,
+ DEFAULT_BLOCKMANAGER_PORT,
+ s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}",
+ resolvedService)
+ val additionalProps = configurationStep.getAdditionalPodSystemProperties()
+ assert(additionalProps("spark.driver.port") === DEFAULT_DRIVER_PORT.toString)
+ assert(additionalProps(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key)
+ === DEFAULT_BLOCKMANAGER_PORT.toString)
+ }
+
+ test("Long prefixes should switch to using a generated name.") {
+ when(clock.getTimeMillis()).thenReturn(10000)
+ val configurationStep = new DriverServiceFeatureStep(
+ KubernetesConf(
+ sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"),
+ KubernetesDriverSpecificConf(
+ None, "main", "app", Seq.empty),
+ LONG_RESOURCE_NAME_PREFIX,
+ "app-id",
+ DRIVER_LABELS,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String]),
+ clock)
+ val driverService = configurationStep
+ .getAdditionalKubernetesResources()
+ .head
+ .asInstanceOf[Service]
+ val expectedServiceName = s"spark-10000${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}"
+ assert(driverService.getMetadata.getName === expectedServiceName)
+ val expectedHostName = s"$expectedServiceName.my-namespace.svc"
+ val additionalProps = configurationStep.getAdditionalPodSystemProperties()
+ verifySparkConfHostNames(additionalProps, expectedHostName)
+ }
+
+ test("Disallow bind address and driver host to be set explicitly.") {
+ try {
+ new DriverServiceFeatureStep(
+ KubernetesConf(
+ sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"),
+ KubernetesDriverSpecificConf(
+ None, "main", "app", Seq.empty),
+ LONG_RESOURCE_NAME_PREFIX,
+ "app-id",
+ DRIVER_LABELS,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String]),
+ clock)
+ fail("The driver bind address should not be allowed.")
+ } catch {
+ case e: Throwable =>
+ assert(e.getMessage ===
+ s"requirement failed: ${DriverServiceFeatureStep.DRIVER_BIND_ADDRESS_KEY} is" +
+ " not supported in Kubernetes mode, as the driver's bind address is managed" +
+ " and set to the driver pod's IP address.")
+ }
+ sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS)
+ sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host")
+ try {
+ new DriverServiceFeatureStep(
+ KubernetesConf(
+ sparkConf,
+ KubernetesDriverSpecificConf(
+ None, "main", "app", Seq.empty),
+ LONG_RESOURCE_NAME_PREFIX,
+ "app-id",
+ DRIVER_LABELS,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String]),
+ clock)
+ fail("The driver host address should not be allowed.")
+ } catch {
+ case e: Throwable =>
+ assert(e.getMessage ===
+ s"requirement failed: ${DriverServiceFeatureStep.DRIVER_HOST_KEY} is" +
+ " not supported in Kubernetes mode, as the driver's hostname will be managed via" +
+ " a Kubernetes service.")
+ }
+ }
+
+ private def verifyService(
+ driverPort: Int,
+ blockManagerPort: Int,
+ expectedServiceName: String,
+ service: Service): Unit = {
+ assert(service.getMetadata.getName === expectedServiceName)
+ assert(service.getSpec.getClusterIP === "None")
+ assert(service.getSpec.getSelector.asScala === DRIVER_LABELS)
+ assert(service.getSpec.getPorts.size() === 2)
+ val driverServicePorts = service.getSpec.getPorts.asScala
+ assert(driverServicePorts.head.getName === DRIVER_PORT_NAME)
+ assert(driverServicePorts.head.getPort.intValue() === driverPort)
+ assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort)
+ assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME)
+ assert(driverServicePorts(1).getPort.intValue() === blockManagerPort)
+ assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort)
+ }
+
+ private def verifySparkConfHostNames(
+ driverSparkConf: Map[String, String], expectedHostName: String): Unit = {
+ assert(driverSparkConf(
+ org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key) === expectedHostName)
+ }
+}
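For readers following the naming assertions above: the "Long prefixes" test implies a simple fallback rule when the resource-name prefix would overflow the Kubernetes service-name limit. The sketch below is illustrative only — the constant values and helper name are assumptions, not DriverServiceFeatureStep's actual implementation.

    // Illustrative sketch of the naming rule the suite exercises; constants and
    // the helper name are assumptions, not the step's real code.
    object ServiceNameSketch {
      val MaxServiceNameLength = 63          // assumed DNS label limit for service names
      val DriverSvcPostfix = "-driver-svc"   // mirrors DRIVER_SVC_POSTFIX in the tests

      def resolve(resourceNamePrefix: String, clockMillis: Long): String = {
        val preferred = s"$resourceNamePrefix$DriverSvcPostfix"
        if (preferred.length <= MaxServiceNameLength) preferred
        else s"spark-$clockMillis$DriverSvcPostfix"   // e.g. "spark-10000-driver-svc"
      }
    }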
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala
new file mode 100644
index 0000000000000..af6b35eae484a
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import io.fabric8.kubernetes.api.model.PodBuilder
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s._
+
+class EnvSecretsFeatureStepSuite extends SparkFunSuite {
+ private val KEY_REF_NAME_FOO = "foo"
+ private val KEY_REF_NAME_BAR = "bar"
+ private val KEY_REF_KEY_FOO = "key_foo"
+ private val KEY_REF_KEY_BAR = "key_bar"
+ private val ENV_NAME_FOO = "MY_FOO"
+ private val ENV_NAME_BAR = "MY_bar"
+
+ test("sets up all keyRefs") {
+ val baseDriverPod = SparkPod.initialPod()
+ val envVarsToKeys = Map(
+ ENV_NAME_BAR -> s"${KEY_REF_NAME_BAR}:${KEY_REF_KEY_BAR}",
+ ENV_NAME_FOO -> s"${KEY_REF_NAME_FOO}:${KEY_REF_KEY_FOO}")
+ val sparkConf = new SparkConf(false)
+ val kubernetesConf = KubernetesConf(
+ sparkConf,
+ KubernetesExecutorSpecificConf("1", new PodBuilder().build()),
+ "resource-name-prefix",
+ "app-id",
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ envVarsToKeys,
+ Map.empty,
+ Seq.empty[String])
+
+ val step = new EnvSecretsFeatureStep(kubernetesConf)
+ val driverContainerWithEnvSecrets = step.configurePod(baseDriverPod).container
+
+ val expectedVars =
+ Seq(s"${ENV_NAME_BAR}", s"${ENV_NAME_FOO}")
+
+ expectedVars.foreach { envName =>
+ assert(KubernetesFeaturesTestUtils.containerHasEnvVar(driverContainerWithEnvSecrets, envName))
+ }
+ }
+}
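The "secretName:secretKey" values asserted above suggest how secret references become container environment variables. A hedged sketch, using only fabric8 builder calls; the helper name is illustrative and this is not the feature step's actual code:

    // Sketch: turn "ENV_NAME" -> "secretName:secretKey" entries into fabric8
    // EnvVars backed by secretKeyRef sources.
    import io.fabric8.kubernetes.api.model.{EnvVar, EnvVarBuilder}

    def secretRefEnvVars(envNamesToKeyRefs: Map[String, String]): Seq[EnvVar] =
      envNamesToKeyRefs.toSeq.map { case (envName, keyRef) =>
        val Array(secretName, secretKey) = keyRef.split(":", 2)
        new EnvVarBuilder()
          .withName(envName)
          .withNewValueFrom()
            .withNewSecretKeyRef()
              .withName(secretName)
              .withKey(secretKey)
            .endSecretKeyRef()
          .endValueFrom()
          .build()
      }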
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala
new file mode 100644
index 0000000000000..f90380e30e52a
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import scala.collection.JavaConverters._
+
+import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder}
+import org.mockito.Matchers
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark.deploy.k8s.SparkPod
+
+object KubernetesFeaturesTestUtils {
+
+ def getMockConfigStepForStepType[T <: KubernetesFeatureConfigStep](
+ stepType: String, stepClass: Class[T]): T = {
+ val mockStep = mock(stepClass)
+ when(mockStep.getAdditionalKubernetesResources()).thenReturn(
+ getSecretsForStepType(stepType))
+
+ when(mockStep.getAdditionalPodSystemProperties())
+ .thenReturn(Map(stepType -> stepType))
+ when(mockStep.configurePod(Matchers.any(classOf[SparkPod])))
+ .thenAnswer(new Answer[SparkPod]() {
+ override def answer(invocation: InvocationOnMock): SparkPod = {
+ val originalPod = invocation.getArgumentAt(0, classOf[SparkPod])
+ val configuredPod = new PodBuilder(originalPod.pod)
+ .editOrNewMetadata()
+ .addToLabels(stepType, stepType)
+ .endMetadata()
+ .build()
+ SparkPod(configuredPod, originalPod.container)
+ }
+ })
+ mockStep
+ }
+
+ def getSecretsForStepType[T <: KubernetesFeatureConfigStep](stepType: String)
+ : Seq[HasMetadata] = {
+ Seq(new SecretBuilder()
+ .withNewMetadata()
+ .withName(stepType)
+ .endMetadata()
+ .build())
+ }
+
+ def containerHasEnvVar(container: Container, envVarName: String): Boolean = {
+ container.getEnv.asScala.exists(envVar => envVar.getName == envVarName)
+ }
+}
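A short usage example of the helpers above, assuming it sits inside a test body with the SparkPod and feature-step types introduced elsewhere in this change; each mocked step labels the pod with its step type, which the builder suites later assert on:

    // Hypothetical usage inside a suite; mirrors how the builder suites below
    // consume the mocked steps.
    val basicStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
      "basic", classOf[BasicDriverFeatureStep])
    val configured = basicStep.configurePod(SparkPod.initialPod())
    assert(configured.pod.getMetadata.getLabels.get("basic") === "basic")
    assert(basicStep.getAdditionalPodSystemProperties() === Map("basic" -> "basic"))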
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala
new file mode 100644
index 0000000000000..bd6ce4b42fc8e
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import io.fabric8.kubernetes.api.model.{EnvVarBuilder, VolumeBuilder, VolumeMountBuilder}
+import org.mockito.Mockito
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod}
+
+class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
+ private val defaultLocalDir = "/var/data/default-local-dir"
+ private var sparkConf: SparkConf = _
+ private var kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf] = _
+
+ before {
+ val realSparkConf = new SparkConf(false)
+ sparkConf = Mockito.spy(realSparkConf)
+ kubernetesConf = KubernetesConf(
+ sparkConf,
+ KubernetesDriverSpecificConf(
+ None,
+ "app-name",
+ "main",
+ Seq.empty),
+ "resource",
+ "app-id",
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String])
+ }
+
+ test("Resolve to default local dir if neither env nor configuration are set") {
+ Mockito.doReturn(null).when(sparkConf).get("spark.local.dir")
+ Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS")
+ val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir)
+ val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod())
+ assert(configuredPod.pod.getSpec.getVolumes.size === 1)
+ assert(configuredPod.pod.getSpec.getVolumes.get(0) ===
+ new VolumeBuilder()
+ .withName(s"spark-local-dir-1")
+ .withNewEmptyDir()
+ .endEmptyDir()
+ .build())
+ assert(configuredPod.container.getVolumeMounts.size === 1)
+ assert(configuredPod.container.getVolumeMounts.get(0) ===
+ new VolumeMountBuilder()
+ .withName(s"spark-local-dir-1")
+ .withMountPath(defaultLocalDir)
+ .build())
+ assert(configuredPod.container.getEnv.size === 1)
+ assert(configuredPod.container.getEnv.get(0) ===
+ new EnvVarBuilder()
+ .withName("SPARK_LOCAL_DIRS")
+ .withValue(defaultLocalDir)
+ .build())
+ }
+
+ test("Use configured local dirs split on comma if provided.") {
+ Mockito.doReturn("/var/data/my-local-dir-1,/var/data/my-local-dir-2")
+ .when(sparkConf).getenv("SPARK_LOCAL_DIRS")
+ val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir)
+ val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod())
+ assert(configuredPod.pod.getSpec.getVolumes.size === 2)
+ assert(configuredPod.pod.getSpec.getVolumes.get(0) ===
+ new VolumeBuilder()
+ .withName(s"spark-local-dir-1")
+ .withNewEmptyDir()
+ .endEmptyDir()
+ .build())
+ assert(configuredPod.pod.getSpec.getVolumes.get(1) ===
+ new VolumeBuilder()
+ .withName(s"spark-local-dir-2")
+ .withNewEmptyDir()
+ .endEmptyDir()
+ .build())
+ assert(configuredPod.container.getVolumeMounts.size === 2)
+ assert(configuredPod.container.getVolumeMounts.get(0) ===
+ new VolumeMountBuilder()
+ .withName(s"spark-local-dir-1")
+ .withMountPath("/var/data/my-local-dir-1")
+ .build())
+ assert(configuredPod.container.getVolumeMounts.get(1) ===
+ new VolumeMountBuilder()
+ .withName(s"spark-local-dir-2")
+ .withMountPath("/var/data/my-local-dir-2")
+ .build())
+ assert(configuredPod.container.getEnv.size === 1)
+ assert(configuredPod.container.getEnv.get(0) ===
+ new EnvVarBuilder()
+ .withName("SPARK_LOCAL_DIRS")
+ .withValue("/var/data/my-local-dir-1,/var/data/my-local-dir-2")
+ .build())
+ }
+}
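The two tests above encode a precedence order for scratch space. A minimal sketch, assuming SPARK_LOCAL_DIRS takes priority over spark.local.dir and then the injected default; the helper name is illustrative, not LocalDirsFeatureStep's API:

    // Sketch of the precedence the tests rely on; not the step's real code.
    import org.apache.spark.SparkConf

    def resolveLocalDirs(conf: SparkConf, defaultLocalDir: String): Seq[String] =
      Option(conf.getenv("SPARK_LOCAL_DIRS"))       // env var wins
        .orElse(conf.getOption("spark.local.dir"))  // then the Spark conf entry
        .getOrElse(defaultLocalDir)                 // finally the injected default
        .split(",")
        .toSeq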
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala
similarity index 63%
rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala
rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala
index 960d0bda1d011..eff75b8a15daa 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala
@@ -14,29 +14,40 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.deploy.k8s.submit.steps
+package org.apache.spark.deploy.k8s.features
+
+import io.fabric8.kubernetes.api.model.PodBuilder
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils}
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SecretVolumeUtils, SparkPod}
-class DriverMountSecretsStepSuite extends SparkFunSuite {
+class MountSecretsFeatureStepSuite extends SparkFunSuite {
private val SECRET_FOO = "foo"
private val SECRET_BAR = "bar"
private val SECRET_MOUNT_PATH = "/etc/secrets/driver"
test("mounts all given secrets") {
- val baseDriverSpec = KubernetesDriverSpec.initialSpec(new SparkConf(false))
+ val baseDriverPod = SparkPod.initialPod()
val secretNamesToMountPaths = Map(
SECRET_FOO -> SECRET_MOUNT_PATH,
SECRET_BAR -> SECRET_MOUNT_PATH)
+ val sparkConf = new SparkConf(false)
+ val kubernetesConf = KubernetesConf(
+ sparkConf,
+ KubernetesExecutorSpecificConf("1", new PodBuilder().build()),
+ "resource-name-prefix",
+ "app-id",
+ Map.empty,
+ Map.empty,
+ secretNamesToMountPaths,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String])
- val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths)
- val mountSecretsStep = new DriverMountSecretsStep(mountSecretsBootstrap)
- val configuredDriverSpec = mountSecretsStep.configureDriver(baseDriverSpec)
- val driverPodWithSecretsMounted = configuredDriverSpec.driverPod
- val driverContainerWithSecretsMounted = configuredDriverSpec.driverContainer
+ val step = new MountSecretsFeatureStep(kubernetesConf)
+ val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod
+ val driverContainerWithSecretsMounted = step.configurePod(baseDriverPod).container
Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName =>
assert(SecretVolumeUtils.podHasVolume(driverPodWithSecretsMounted, volumeName))
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala
new file mode 100644
index 0000000000000..0f2bf2fa1d9b5
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features.bindings
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.submit.PythonMainAppResource
+
+class JavaDriverFeatureStepSuite extends SparkFunSuite {
+
+ test("Java Step modifies container correctly") {
+ val baseDriverPod = SparkPod.initialPod()
+ val sparkConf = new SparkConf(false)
+ val kubernetesConf = KubernetesConf(
+ sparkConf,
+ KubernetesDriverSpecificConf(
+ Some(PythonMainAppResource("local:///main.jar")),
+ "test-class",
+ "java-runner",
+ Seq("5 7")),
+ appResourceNamePrefix = "",
+ appId = "",
+ roleLabels = Map.empty,
+ roleAnnotations = Map.empty,
+ roleSecretNamesToMountPaths = Map.empty,
+ roleSecretEnvNamesToKeyRefs = Map.empty,
+ roleEnvs = Map.empty,
+ sparkFiles = Seq.empty[String])
+
+ val step = new JavaDriverFeatureStep(kubernetesConf)
+ val driverPod = step.configurePod(baseDriverPod).pod
+ val driverContainerwithJavaStep = step.configurePod(baseDriverPod).container
+ assert(driverContainerwithJavaStep.getArgs.size === 7)
+ val args = driverContainerwithJavaStep
+ .getArgs.asScala
+ assert(args === List(
+ "driver",
+ "--properties-file", SPARK_CONF_PATH,
+ "--class", "test-class",
+ "spark-internal", "5 7"))
+
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala
new file mode 100644
index 0000000000000..a1f9a5d9e264e
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features.bindings
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.submit.PythonMainAppResource
+
+class PythonDriverFeatureStepSuite extends SparkFunSuite {
+
+ test("Python Step modifies container correctly") {
+ val expectedMainResource = "/main.py"
+ val mainResource = "local:///main.py"
+ val pyFiles = Seq("local:///example2.py", "local:///example3.py")
+ val expectedPySparkFiles =
+ "/example2.py:/example3.py"
+ val baseDriverPod = SparkPod.initialPod()
+ val sparkConf = new SparkConf(false)
+ .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource)
+ .set(KUBERNETES_PYSPARK_PY_FILES, pyFiles.mkString(","))
+ .set("spark.files", "local:///example.py")
+ .set(PYSPARK_MAJOR_PYTHON_VERSION, "2")
+ val kubernetesConf = KubernetesConf(
+ sparkConf,
+ KubernetesDriverSpecificConf(
+ Some(PythonMainAppResource("local:///main.py")),
+ "test-app",
+ "python-runner",
+ Seq("5 7")),
+ appResourceNamePrefix = "",
+ appId = "",
+ roleLabels = Map.empty,
+ roleAnnotations = Map.empty,
+ roleSecretNamesToMountPaths = Map.empty,
+ roleSecretEnvNamesToKeyRefs = Map.empty,
+ roleEnvs = Map.empty,
+ sparkFiles = Seq.empty[String])
+
+ val step = new PythonDriverFeatureStep(kubernetesConf)
+ val driverPod = step.configurePod(baseDriverPod).pod
+ val driverContainerwithPySpark = step.configurePod(baseDriverPod).container
+ assert(driverContainerwithPySpark.getEnv.size === 4)
+ val envs = driverContainerwithPySpark
+ .getEnv
+ .asScala
+ .map(env => (env.getName, env.getValue))
+ .toMap
+ assert(envs(ENV_PYSPARK_PRIMARY) === expectedMainResource)
+ assert(envs(ENV_PYSPARK_FILES) === expectedPySparkFiles)
+ assert(envs(ENV_PYSPARK_ARGS) === "5 7")
+ assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "2")
+ }
+ test("Python Step testing empty pyfiles") {
+ val mainResource = "local:///main.py"
+ val baseDriverPod = SparkPod.initialPod()
+ val sparkConf = new SparkConf(false)
+ .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource)
+ .set(PYSPARK_MAJOR_PYTHON_VERSION, "3")
+ val kubernetesConf = KubernetesConf(
+ sparkConf,
+ KubernetesDriverSpecificConf(
+ Some(PythonMainAppResource("local:///main.py")),
+ "test-class-py",
+ "python-runner",
+ Seq.empty[String]),
+ appResourceNamePrefix = "",
+ appId = "",
+ roleLabels = Map.empty,
+ roleAnnotations = Map.empty,
+ roleSecretNamesToMountPaths = Map.empty,
+ roleSecretEnvNamesToKeyRefs = Map.empty,
+ roleEnvs = Map.empty,
+ sparkFiles = Seq.empty[String])
+ val step = new PythonDriverFeatureStep(kubernetesConf)
+ val driverContainerwithPySpark = step.configurePod(baseDriverPod).container
+ val args = driverContainerwithPySpark
+ .getArgs.asScala
+ assert(driverContainerwithPySpark.getArgs.size === 5)
+ assert(args === List(
+ "driver-py",
+ "--properties-file", SPARK_CONF_PATH,
+ "--class", "test-class-py"))
+ val envs = driverContainerwithPySpark
+ .getEnv
+ .asScala
+ .map(env => (env.getName, env.getValue))
+ .toMap
+ assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "3")
+ }
+}
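The expected values above ("/main.py", "/example2.py:/example3.py") imply that the step strips the local:// scheme and joins py-files with ':'. A sketch with illustrative names, not the step's actual code:

    // Sketch of the path handling the assertions imply.
    def toContainerLocalPath(resource: String): String =
      resource.stripPrefix("local://")

    def joinPyFiles(pyFiles: Seq[String]): String =
      pyFiles.map(toContainerLocalPath).mkString(":")

    // joinPyFiles(Seq("local:///example2.py", "local:///example3.py"))
    //   == "/example2.py:/example3.py"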
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala
index bf4ec04893204..d045d9ae89c07 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala
@@ -16,38 +16,99 @@
*/
package org.apache.spark.deploy.k8s.submit
-import scala.collection.JavaConverters._
-
-import com.google.common.collect.Iterables
import io.fabric8.kubernetes.api.model._
import io.fabric8.kubernetes.client.{KubernetesClient, Watch}
import io.fabric8.kubernetes.client.dsl.{MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource}
import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations}
import org.mockito.Mockito.{doReturn, verify, when}
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfter
import org.scalatest.mockito.MockitoSugar._
import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod}
import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep
+import org.apache.spark.deploy.k8s.Fabric8Aliases._
class ClientSuite extends SparkFunSuite with BeforeAndAfter {
private val DRIVER_POD_UID = "pod-id"
private val DRIVER_POD_API_VERSION = "v1"
private val DRIVER_POD_KIND = "pod"
+ private val KUBERNETES_RESOURCE_PREFIX = "resource-example"
+ private val POD_NAME = "driver"
+ private val CONTAINER_NAME = "container"
+ private val APP_ID = "app-id"
+ private val APP_NAME = "app"
+ private val MAIN_CLASS = "main"
+ private val APP_ARGS = Seq("arg1", "arg2")
+ private val RESOLVED_JAVA_OPTIONS = Map(
+ "conf1key" -> "conf1value",
+ "conf2key" -> "conf2value")
+ private val BUILT_DRIVER_POD =
+ new PodBuilder()
+ .withNewMetadata()
+ .withName(POD_NAME)
+ .endMetadata()
+ .withNewSpec()
+ .withHostname("localhost")
+ .endSpec()
+ .build()
+ private val BUILT_DRIVER_CONTAINER = new ContainerBuilder().withName(CONTAINER_NAME).build()
+ private val ADDITIONAL_RESOURCES = Seq(
+ new SecretBuilder().withNewMetadata().withName("secret").endMetadata().build())
+
+ private val BUILT_KUBERNETES_SPEC = KubernetesDriverSpec(
+ SparkPod(BUILT_DRIVER_POD, BUILT_DRIVER_CONTAINER),
+ ADDITIONAL_RESOURCES,
+ RESOLVED_JAVA_OPTIONS)
+
+ private val FULL_EXPECTED_CONTAINER = new ContainerBuilder(BUILT_DRIVER_CONTAINER)
+ .addNewEnv()
+ .withName(ENV_SPARK_CONF_DIR)
+ .withValue(SPARK_CONF_DIR_INTERNAL)
+ .endEnv()
+ .addNewVolumeMount()
+ .withName(SPARK_CONF_VOLUME)
+ .withMountPath(SPARK_CONF_DIR_INTERNAL)
+ .endVolumeMount()
+ .build()
+ private val FULL_EXPECTED_POD = new PodBuilder(BUILT_DRIVER_POD)
+ .editSpec()
+ .addToContainers(FULL_EXPECTED_CONTAINER)
+ .addNewVolume()
+ .withName(SPARK_CONF_VOLUME)
+ .withNewConfigMap().withName(s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map").endConfigMap()
+ .endVolume()
+ .endSpec()
+ .build()
+
+ private val POD_WITH_OWNER_REFERENCE = new PodBuilder(FULL_EXPECTED_POD)
+ .editMetadata()
+ .withUid(DRIVER_POD_UID)
+ .endMetadata()
+ .withApiVersion(DRIVER_POD_API_VERSION)
+ .withKind(DRIVER_POD_KIND)
+ .build()
- private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[
- HasMetadata, Boolean]
- private type Pods = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]
+ private val ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES = ADDITIONAL_RESOURCES.map { secret =>
+ new SecretBuilder(secret)
+ .editMetadata()
+ .addNewOwnerReference()
+ .withName(POD_NAME)
+ .withApiVersion(DRIVER_POD_API_VERSION)
+ .withKind(DRIVER_POD_KIND)
+ .withController(true)
+ .withUid(DRIVER_POD_UID)
+ .endOwnerReference()
+ .endMetadata()
+ .build()
+ }
@Mock
private var kubernetesClient: KubernetesClient = _
@Mock
- private var podOperations: Pods = _
+ private var podOperations: PODS = _
@Mock
private var namedPods: PodResource[Pod, DoneablePod] = _
@@ -56,179 +117,93 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter {
private var loggingPodStatusWatcher: LoggingPodStatusWatcher = _
@Mock
- private var resourceList: ResourceList = _
+ private var driverBuilder: KubernetesDriverBuilder = _
- private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep)
+ @Mock
+ private var resourceList: RESOURCE_LIST = _
+
+ private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _
+
+ private var sparkConf: SparkConf = _
private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _
private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _
before {
MockitoAnnotations.initMocks(this)
+ sparkConf = new SparkConf(false)
+ kubernetesConf = KubernetesConf[KubernetesDriverSpecificConf](
+ sparkConf,
+ KubernetesDriverSpecificConf(None, MAIN_CLASS, APP_NAME, APP_ARGS),
+ KUBERNETES_RESOURCE_PREFIX,
+ APP_ID,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String])
+ when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC)
when(kubernetesClient.pods()).thenReturn(podOperations)
- when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods)
+ when(podOperations.withName(POD_NAME)).thenReturn(namedPods)
createdPodArgumentCaptor = ArgumentCaptor.forClass(classOf[Pod])
createdResourcesArgumentCaptor = ArgumentCaptor.forClass(classOf[HasMetadata])
- when(podOperations.create(createdPodArgumentCaptor.capture())).thenAnswer(new Answer[Pod] {
- override def answer(invocation: InvocationOnMock): Pod = {
- new PodBuilder(invocation.getArgumentAt(0, classOf[Pod]))
- .editMetadata()
- .withUid(DRIVER_POD_UID)
- .endMetadata()
- .withApiVersion(DRIVER_POD_API_VERSION)
- .withKind(DRIVER_POD_KIND)
- .build()
- }
- })
- when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods)
+ when(podOperations.create(FULL_EXPECTED_POD)).thenReturn(POD_WITH_OWNER_REFERENCE)
when(namedPods.watch(loggingPodStatusWatcher)).thenReturn(mock[Watch])
doReturn(resourceList)
.when(kubernetesClient)
.resourceList(createdResourcesArgumentCaptor.capture())
}
- test("The client should configure the pod with the submission steps.") {
+ test("The client should configure the pod using the builder.") {
val submissionClient = new Client(
- submissionSteps,
- new SparkConf(false),
+ driverBuilder,
+ kubernetesConf,
kubernetesClient,
false,
"spark",
- loggingPodStatusWatcher)
+ loggingPodStatusWatcher,
+ KUBERNETES_RESOURCE_PREFIX)
submissionClient.run()
- val createdPod = createdPodArgumentCaptor.getValue
- assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName)
- assert(createdPod.getMetadata.getLabels.asScala ===
- Map(FirstTestConfigurationStep.labelKey -> FirstTestConfigurationStep.labelValue))
- assert(createdPod.getMetadata.getAnnotations.asScala ===
- Map(SecondTestConfigurationStep.annotationKey ->
- SecondTestConfigurationStep.annotationValue))
- assert(createdPod.getSpec.getContainers.size() === 1)
- assert(createdPod.getSpec.getContainers.get(0).getName ===
- SecondTestConfigurationStep.containerName)
+ verify(podOperations).create(FULL_EXPECTED_POD)
}
- test("The client should create the secondary Kubernetes resources.") {
+ test("The client should create Kubernetes resources") {
val submissionClient = new Client(
- submissionSteps,
- new SparkConf(false),
+ driverBuilder,
+ kubernetesConf,
kubernetesClient,
false,
"spark",
- loggingPodStatusWatcher)
+ loggingPodStatusWatcher,
+ KUBERNETES_RESOURCE_PREFIX)
submissionClient.run()
- val createdPod = createdPodArgumentCaptor.getValue
val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues
- assert(otherCreatedResources.size === 1)
- val createdResource = Iterables.getOnlyElement(otherCreatedResources).asInstanceOf[Secret]
- assert(createdResource.getMetadata.getName === FirstTestConfigurationStep.secretName)
- assert(createdResource.getData.asScala ===
- Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData))
- val ownerReference = Iterables.getOnlyElement(createdResource.getMetadata.getOwnerReferences)
- assert(ownerReference.getName === createdPod.getMetadata.getName)
- assert(ownerReference.getKind === DRIVER_POD_KIND)
- assert(ownerReference.getUid === DRIVER_POD_UID)
- assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION)
- }
-
- test("The client should attach the driver container with the appropriate JVM options.") {
- val sparkConf = new SparkConf(false)
- .set("spark.logConf", "true")
- .set(
- org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS,
- "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails")
- val submissionClient = new Client(
- submissionSteps,
- sparkConf,
- kubernetesClient,
- false,
- "spark",
- loggingPodStatusWatcher)
- submissionClient.run()
- val createdPod = createdPodArgumentCaptor.getValue
- val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers)
- assert(driverContainer.getName === SecondTestConfigurationStep.containerName)
- val driverJvmOptsEnvs = driverContainer.getEnv.asScala.filter { env =>
- env.getName.startsWith(ENV_JAVA_OPT_PREFIX)
- }.sortBy(_.getName)
- assert(driverJvmOptsEnvs.size === 4)
-
- val expectedJvmOptsValues = Seq(
- "-Dspark.logConf=true",
- s"-D${SecondTestConfigurationStep.sparkConfKey}=" +
- s"${SecondTestConfigurationStep.sparkConfValue}",
- "-XX:+HeapDumpOnOutOfMemoryError",
- "-XX:+PrintGCDetails")
- driverJvmOptsEnvs.zip(expectedJvmOptsValues).zipWithIndex.foreach {
- case ((resolvedEnv, expectedJvmOpt), index) =>
- assert(resolvedEnv.getName === s"$ENV_JAVA_OPT_PREFIX$index")
- assert(resolvedEnv.getValue === expectedJvmOpt)
- }
+ assert(otherCreatedResources.size === 2)
+ val secrets = otherCreatedResources.toArray.filter(_.isInstanceOf[Secret]).toSeq
+ assert(secrets === ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES)
+ val configMaps = otherCreatedResources.toArray
+ .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap])
+ assert(secrets.nonEmpty)
+ assert(configMaps.nonEmpty)
+ val configMap = configMaps.head
+ assert(configMap.getMetadata.getName ===
+ s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map")
+ assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME))
+ assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf1key=conf1value"))
+ assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf2key=conf2value"))
}
test("Waiting for app completion should stall on the watcher") {
val submissionClient = new Client(
- submissionSteps,
- new SparkConf(false),
+ driverBuilder,
+ kubernetesConf,
kubernetesClient,
true,
"spark",
- loggingPodStatusWatcher)
+ loggingPodStatusWatcher,
+ KUBERNETES_RESOURCE_PREFIX)
submissionClient.run()
verify(loggingPodStatusWatcher).awaitCompletion()
}
-
-}
-
-private object FirstTestConfigurationStep extends DriverConfigurationStep {
-
- val podName = "test-pod"
- val secretName = "test-secret"
- val labelKey = "first-submit"
- val labelValue = "true"
- val secretKey = "secretKey"
- val secretData = "secretData"
-
- override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
- val modifiedPod = new PodBuilder(driverSpec.driverPod)
- .editMetadata()
- .withName(podName)
- .addToLabels(labelKey, labelValue)
- .endMetadata()
- .build()
- val additionalResource = new SecretBuilder()
- .withNewMetadata()
- .withName(secretName)
- .endMetadata()
- .addToData(secretKey, secretData)
- .build()
- driverSpec.copy(
- driverPod = modifiedPod,
- otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(additionalResource))
- }
-}
-
-private object SecondTestConfigurationStep extends DriverConfigurationStep {
-
- val annotationKey = "second-submit"
- val annotationValue = "submitted"
- val sparkConfKey = "spark.custom-conf"
- val sparkConfValue = "custom-conf-value"
- val containerName = "driverContainer"
-
- override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
- val modifiedPod = new PodBuilder(driverSpec.driverPod)
- .editMetadata()
- .addToAnnotations(annotationKey, annotationValue)
- .endMetadata()
- .build()
- val resolvedSparkConf = driverSpec.driverSparkConf.clone().set(sparkConfKey, sparkConfValue)
- val modifiedContainer = new ContainerBuilder(driverSpec.driverContainer)
- .withName(containerName)
- .build()
- driverSpec.copy(
- driverPod = modifiedPod,
- driverSparkConf = resolvedSparkConf,
- driverContainer = modifiedContainer)
- }
}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala
deleted file mode 100644
index 033d303e946fd..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit
-
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.submit.steps._
-
-class DriverConfigOrchestratorSuite extends SparkFunSuite {
-
- private val DRIVER_IMAGE = "driver-image"
- private val IC_IMAGE = "init-container-image"
- private val APP_ID = "spark-app-id"
- private val LAUNCH_TIME = 975256L
- private val APP_NAME = "spark"
- private val MAIN_CLASS = "org.apache.spark.examples.SparkPi"
- private val APP_ARGS = Array("arg1", "arg2")
- private val SECRET_FOO = "foo"
- private val SECRET_BAR = "bar"
- private val SECRET_MOUNT_PATH = "/etc/secrets/driver"
-
- test("Base submission steps with a main app resource.") {
- val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE)
- val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar")
- val orchestrator = new DriverConfigOrchestrator(
- APP_ID,
- LAUNCH_TIME,
- Some(mainAppResource),
- APP_NAME,
- MAIN_CLASS,
- APP_ARGS,
- sparkConf)
- validateStepTypes(
- orchestrator,
- classOf[BasicDriverConfigurationStep],
- classOf[DriverServiceBootstrapStep],
- classOf[DriverKubernetesCredentialsStep],
- classOf[DependencyResolutionStep]
- )
- }
-
- test("Base submission steps without a main app resource.") {
- val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE)
- val orchestrator = new DriverConfigOrchestrator(
- APP_ID,
- LAUNCH_TIME,
- Option.empty,
- APP_NAME,
- MAIN_CLASS,
- APP_ARGS,
- sparkConf)
- validateStepTypes(
- orchestrator,
- classOf[BasicDriverConfigurationStep],
- classOf[DriverServiceBootstrapStep],
- classOf[DriverKubernetesCredentialsStep]
- )
- }
-
- test("Submission steps with an init-container.") {
- val sparkConf = new SparkConf(false)
- .set(CONTAINER_IMAGE, DRIVER_IMAGE)
- .set(INIT_CONTAINER_IMAGE.key, IC_IMAGE)
- .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar")
- val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar")
- val orchestrator = new DriverConfigOrchestrator(
- APP_ID,
- LAUNCH_TIME,
- Some(mainAppResource),
- APP_NAME,
- MAIN_CLASS,
- APP_ARGS,
- sparkConf)
- validateStepTypes(
- orchestrator,
- classOf[BasicDriverConfigurationStep],
- classOf[DriverServiceBootstrapStep],
- classOf[DriverKubernetesCredentialsStep],
- classOf[DependencyResolutionStep],
- classOf[DriverInitContainerBootstrapStep])
- }
-
- test("Submission steps with driver secrets to mount") {
- val sparkConf = new SparkConf(false)
- .set(CONTAINER_IMAGE, DRIVER_IMAGE)
- .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH)
- .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH)
- val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar")
- val orchestrator = new DriverConfigOrchestrator(
- APP_ID,
- LAUNCH_TIME,
- Some(mainAppResource),
- APP_NAME,
- MAIN_CLASS,
- APP_ARGS,
- sparkConf)
- validateStepTypes(
- orchestrator,
- classOf[BasicDriverConfigurationStep],
- classOf[DriverServiceBootstrapStep],
- classOf[DriverKubernetesCredentialsStep],
- classOf[DependencyResolutionStep],
- classOf[DriverMountSecretsStep])
- }
-
- test("Submission using client local dependencies") {
- val sparkConf = new SparkConf(false)
- .set(CONTAINER_IMAGE, DRIVER_IMAGE)
- var orchestrator = new DriverConfigOrchestrator(
- APP_ID,
- LAUNCH_TIME,
- Some(JavaMainAppResource("file:///var/apps/jars/main.jar")),
- APP_NAME,
- MAIN_CLASS,
- APP_ARGS,
- sparkConf)
- assertThrows[SparkException] {
- orchestrator.getAllConfigurationSteps
- }
-
- sparkConf.set("spark.files", "/path/to/file1,/path/to/file2")
- orchestrator = new DriverConfigOrchestrator(
- APP_ID,
- LAUNCH_TIME,
- Some(JavaMainAppResource("local:///var/apps/jars/main.jar")),
- APP_NAME,
- MAIN_CLASS,
- APP_ARGS,
- sparkConf)
- assertThrows[SparkException] {
- orchestrator.getAllConfigurationSteps
- }
- }
-
- private def validateStepTypes(
- orchestrator: DriverConfigOrchestrator,
- types: Class[_ <: DriverConfigurationStep]*): Unit = {
- val steps = orchestrator.getAllConfigurationSteps
- assert(steps.size === types.size)
- assert(steps.map(_.getClass) === types)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
new file mode 100644
index 0000000000000..4e8c300543430
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.submit
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf}
+import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep}
+import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep}
+
+class KubernetesDriverBuilderSuite extends SparkFunSuite {
+
+ private val BASIC_STEP_TYPE = "basic"
+ private val CREDENTIALS_STEP_TYPE = "credentials"
+ private val SERVICE_STEP_TYPE = "service"
+ private val LOCAL_DIRS_STEP_TYPE = "local-dirs"
+ private val SECRETS_STEP_TYPE = "mount-secrets"
+ private val JAVA_STEP_TYPE = "java-bindings"
+ private val PYSPARK_STEP_TYPE = "pyspark-bindings"
+ private val ENV_SECRETS_STEP_TYPE = "env-secrets"
+
+ private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep])
+
+ private val credentialsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ CREDENTIALS_STEP_TYPE, classOf[DriverKubernetesCredentialsFeatureStep])
+
+ private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep])
+
+ private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep])
+
+ private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep])
+
+ private val javaStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ JAVA_STEP_TYPE, classOf[JavaDriverFeatureStep])
+
+ private val pythonStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ PYSPARK_STEP_TYPE, classOf[PythonDriverFeatureStep])
+
+ private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep])
+
+ private val builderUnderTest: KubernetesDriverBuilder =
+ new KubernetesDriverBuilder(
+ _ => basicFeatureStep,
+ _ => credentialsStep,
+ _ => serviceStep,
+ _ => secretsStep,
+ _ => envSecretsStep,
+ _ => localDirsStep,
+ _ => javaStep,
+ _ => pythonStep)
+
+ test("Apply fundamental steps all the time.") {
+ val conf = KubernetesConf(
+ new SparkConf(false),
+ KubernetesDriverSpecificConf(
+ Some(JavaMainAppResource("example.jar")),
+ "test-app",
+ "main",
+ Seq.empty),
+ "prefix",
+ "appId",
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String])
+ validateStepTypesApplied(
+ builderUnderTest.buildFromFeatures(conf),
+ BASIC_STEP_TYPE,
+ CREDENTIALS_STEP_TYPE,
+ SERVICE_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE,
+ JAVA_STEP_TYPE)
+ }
+
+ test("Apply secrets step if secrets are present.") {
+ val conf = KubernetesConf(
+ new SparkConf(false),
+ KubernetesDriverSpecificConf(
+ None,
+ "test-app",
+ "main",
+ Seq.empty),
+ "prefix",
+ "appId",
+ Map.empty,
+ Map.empty,
+ Map("secret" -> "secretMountPath"),
+ Map("EnvName" -> "SecretName:secretKey"),
+ Map.empty,
+ Seq.empty[String])
+ validateStepTypesApplied(
+ builderUnderTest.buildFromFeatures(conf),
+ BASIC_STEP_TYPE,
+ CREDENTIALS_STEP_TYPE,
+ SERVICE_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE,
+ SECRETS_STEP_TYPE,
+ ENV_SECRETS_STEP_TYPE,
+ JAVA_STEP_TYPE)
+ }
+
+ test("Apply Java step if main resource is none.") {
+ val conf = KubernetesConf(
+ new SparkConf(false),
+ KubernetesDriverSpecificConf(
+ None,
+ "test-app",
+ "main",
+ Seq.empty),
+ "prefix",
+ "appId",
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String])
+ validateStepTypesApplied(
+ builderUnderTest.buildFromFeatures(conf),
+ BASIC_STEP_TYPE,
+ CREDENTIALS_STEP_TYPE,
+ SERVICE_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE,
+ JAVA_STEP_TYPE)
+ }
+
+ test("Apply Python step if main resource is python.") {
+ val conf = KubernetesConf(
+ new SparkConf(false),
+ KubernetesDriverSpecificConf(
+ Some(PythonMainAppResource("example.py")),
+ "test-app",
+ "main",
+ Seq.empty),
+ "prefix",
+ "appId",
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String])
+ validateStepTypesApplied(
+ builderUnderTest.buildFromFeatures(conf),
+ BASIC_STEP_TYPE,
+ CREDENTIALS_STEP_TYPE,
+ SERVICE_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE,
+ PYSPARK_STEP_TYPE)
+ }
+
+ private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*)
+ : Unit = {
+ assert(resolvedSpec.systemProperties.size === stepTypes.size)
+ stepTypes.foreach { stepType =>
+ assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType)
+ assert(resolvedSpec.driverKubernetesResources.containsSlice(
+ KubernetesFeaturesTestUtils.getSecretsForStepType(stepType)))
+ assert(resolvedSpec.systemProperties(stepType) === stepType)
+ }
+ }
+}
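
Note on the suite above: it never exercises real feature-step behaviour, it only checks which steps the builder selects. Each mocked step tags everything it touches with its own step type, so validateStepTypesApplied can read the applied steps back out of the pod labels, the resolved system properties, and the attached resources. The sketch below is one way such a stub could be written by hand; it is illustrative only and assumes the KubernetesFeatureConfigStep interface used elsewhere in this change (configurePod, getAdditionalPodSystemProperties, getAdditionalKubernetesResources). The PR's actual helper, KubernetesFeaturesTestUtils.getMockConfigStepForStepType, presumably builds an equivalent stub (likely with Mockito) against the concrete step classes so it also satisfies the builder's constructor types, which a plain hand-rolled class like this would not.

import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder}

import org.apache.spark.deploy.k8s.SparkPod
import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep

// Illustrative stand-in for a mocked feature step: everything it produces is
// keyed by `stepType`, which is exactly what validateStepTypesApplied asserts on.
class StepTypeTaggingStep(stepType: String) extends KubernetesFeatureConfigStep {

  // Tag the pod metadata with a label named after the step type.
  override def configurePod(pod: SparkPod): SparkPod = {
    val taggedPod = new PodBuilder(pod.pod)
      .editOrNewMetadata()
        .addToLabels(stepType, stepType)
        .endMetadata()
      .build()
    SparkPod(taggedPod, pod.container)
  }

  // Surface the step type as a system property on the resolved driver spec.
  override def getAdditionalPodSystemProperties(): Map[String, String] =
    Map(stepType -> stepType)

  // Attach one secret named after the step type as an additional resource.
  override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq(
    new SecretBuilder()
      .withNewMetadata()
        .withName(stepType)
        .endMetadata()
      .build())
}
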
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala
deleted file mode 100644
index b136f2c02ffba..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala
+++ /dev/null
@@ -1,118 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps
-
-import scala.collection.JavaConverters._
-
-import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder}
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
-
-class BasicDriverConfigurationStepSuite extends SparkFunSuite {
-
- private val APP_ID = "spark-app-id"
- private val RESOURCE_NAME_PREFIX = "spark"
- private val DRIVER_LABELS = Map("labelkey" -> "labelvalue")
- private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent"
- private val APP_NAME = "spark-test"
- private val MAIN_CLASS = "org.apache.spark.examples.SparkPi"
- private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"")
- private val CUSTOM_ANNOTATION_KEY = "customAnnotation"
- private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue"
- private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1"
- private val DRIVER_CUSTOM_ENV_KEY2 = "customDriverEnv2"
-
- test("Set all possible configurations from the user.") {
- val sparkConf = new SparkConf()
- .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod")
- .set(org.apache.spark.internal.config.DRIVER_CLASS_PATH, "/opt/spark/spark-examples.jar")
- .set("spark.driver.cores", "2")
- .set(KUBERNETES_DRIVER_LIMIT_CORES, "4")
- .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M")
- .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L)
- .set(CONTAINER_IMAGE, "spark-driver:latest")
- .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE)
- .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1")
- .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2")
-
- val submissionStep = new BasicDriverConfigurationStep(
- APP_ID,
- RESOURCE_NAME_PREFIX,
- DRIVER_LABELS,
- CONTAINER_IMAGE_PULL_POLICY,
- APP_NAME,
- MAIN_CLASS,
- APP_ARGS,
- sparkConf)
- val basePod = new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build()
- val baseDriverSpec = KubernetesDriverSpec(
- driverPod = basePod,
- driverContainer = new ContainerBuilder().build(),
- driverSparkConf = new SparkConf(false),
- otherKubernetesResources = Seq.empty[HasMetadata])
- val preparedDriverSpec = submissionStep.configureDriver(baseDriverSpec)
-
- assert(preparedDriverSpec.driverContainer.getName === DRIVER_CONTAINER_NAME)
- assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest")
- assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY)
-
- assert(preparedDriverSpec.driverContainer.getEnv.size === 7)
- val envs = preparedDriverSpec.driverContainer
- .getEnv
- .asScala
- .map(env => (env.getName, env.getValue))
- .toMap
- assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar")
- assert(envs(ENV_DRIVER_MEMORY) === "256M")
- assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS)
- assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"")
- assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1")
- assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2")
-
- assert(preparedDriverSpec.driverContainer.getEnv.asScala.exists(envVar =>
- envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) &&
- envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") &&
- envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP")))
-
- val resourceRequirements = preparedDriverSpec.driverContainer.getResources
- val requests = resourceRequirements.getRequests.asScala
- assert(requests("cpu").getAmount === "2")
- assert(requests("memory").getAmount === "256Mi")
- val limits = resourceRequirements.getLimits.asScala
- assert(limits("memory").getAmount === "456Mi")
- assert(limits("cpu").getAmount === "4")
-
- val driverPodMetadata = preparedDriverSpec.driverPod.getMetadata
- assert(driverPodMetadata.getName === "spark-driver-pod")
- assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS)
- val expectedAnnotations = Map(
- CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE,
- SPARK_APP_NAME_ANNOTATION -> APP_NAME)
- assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations)
- assert(preparedDriverSpec.driverPod.getSpec.getRestartPolicy === "Never")
-
- val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap
- val expectedSparkConf = Map(
- KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod",
- "spark.app.id" -> APP_ID,
- KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX)
- assert(resolvedSparkConf === expectedSparkConf)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala
deleted file mode 100644
index 991b03cafb76c..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps
-
-import java.io.File
-
-import scala.collection.JavaConverters._
-
-import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder}
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
-
-class DependencyResolutionStepSuite extends SparkFunSuite {
-
- private val SPARK_JARS = Seq(
- "hdfs://localhost:9000/apps/jars/jar1.jar",
- "file:///home/user/apps/jars/jar2.jar",
- "local:///var/apps/jars/jar3.jar")
-
- private val SPARK_FILES = Seq(
- "file:///home/user/apps/files/file1.txt",
- "hdfs://localhost:9000/apps/files/file2.txt",
- "local:///var/apps/files/file3.txt")
-
- private val JARS_DOWNLOAD_PATH = "/mnt/spark-data/jars"
- private val FILES_DOWNLOAD_PATH = "/mnt/spark-data/files"
-
- test("Added dependencies should be resolved in Spark configuration and environment") {
- val dependencyResolutionStep = new DependencyResolutionStep(
- SPARK_JARS,
- SPARK_FILES,
- JARS_DOWNLOAD_PATH,
- FILES_DOWNLOAD_PATH)
- val driverPod = new PodBuilder().build()
- val baseDriverSpec = KubernetesDriverSpec(
- driverPod = driverPod,
- driverContainer = new ContainerBuilder().build(),
- driverSparkConf = new SparkConf(false),
- otherKubernetesResources = Seq.empty[HasMetadata])
- val preparedDriverSpec = dependencyResolutionStep.configureDriver(baseDriverSpec)
- assert(preparedDriverSpec.driverPod === driverPod)
- assert(preparedDriverSpec.otherKubernetesResources.isEmpty)
- val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet
- val expectedResolvedSparkJars = Set(
- "hdfs://localhost:9000/apps/jars/jar1.jar",
- s"$JARS_DOWNLOAD_PATH/jar2.jar",
- "/var/apps/jars/jar3.jar")
- assert(resolvedSparkJars === expectedResolvedSparkJars)
- val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet
- val expectedResolvedSparkFiles = Set(
- s"$FILES_DOWNLOAD_PATH/file1.txt",
- s"hdfs://localhost:9000/apps/files/file2.txt",
- s"/var/apps/files/file3.txt")
- assert(resolvedSparkFiles === expectedResolvedSparkFiles)
- val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala
- assert(driverEnv.size === 1)
- assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH)
- val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet
- val expectedResolvedDriverClasspath = Set(
- s"$JARS_DOWNLOAD_PATH/jar1.jar",
- s"$JARS_DOWNLOAD_PATH/jar2.jar",
- "/var/apps/jars/jar3.jar")
- assert(resolvedDriverClasspath === expectedResolvedDriverClasspath)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala
deleted file mode 100644
index 758871e2ba356..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps
-
-import java.io.StringReader
-import java.util.Properties
-
-import scala.collection.JavaConverters._
-
-import com.google.common.collect.Maps
-import io.fabric8.kubernetes.api.model.{ConfigMap, ContainerBuilder, HasMetadata, PodBuilder, SecretBuilder}
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
-import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec}
-import org.apache.spark.util.Utils
-
-class DriverInitContainerBootstrapStepSuite extends SparkFunSuite {
-
- private val CONFIG_MAP_NAME = "spark-init-config-map"
- private val CONFIG_MAP_KEY = "spark-init-config-map-key"
-
- test("The init container bootstrap step should use all of the init container steps") {
- val baseDriverSpec = KubernetesDriverSpec(
- driverPod = new PodBuilder().build(),
- driverContainer = new ContainerBuilder().build(),
- driverSparkConf = new SparkConf(false),
- otherKubernetesResources = Seq.empty[HasMetadata])
- val initContainerSteps = Seq(
- FirstTestInitContainerConfigurationStep,
- SecondTestInitContainerConfigurationStep)
- val bootstrapStep = new DriverInitContainerBootstrapStep(
- initContainerSteps,
- CONFIG_MAP_NAME,
- CONFIG_MAP_KEY)
-
- val preparedDriverSpec = bootstrapStep.configureDriver(baseDriverSpec)
-
- assert(preparedDriverSpec.driverPod.getMetadata.getLabels.asScala ===
- FirstTestInitContainerConfigurationStep.additionalLabels)
- val additionalDriverEnv = preparedDriverSpec.driverContainer.getEnv.asScala
- assert(additionalDriverEnv.size === 1)
- assert(additionalDriverEnv.head.getName ===
- FirstTestInitContainerConfigurationStep.additionalMainContainerEnvKey)
- assert(additionalDriverEnv.head.getValue ===
- FirstTestInitContainerConfigurationStep.additionalMainContainerEnvValue)
-
- assert(preparedDriverSpec.otherKubernetesResources.size === 2)
- assert(preparedDriverSpec.otherKubernetesResources.contains(
- FirstTestInitContainerConfigurationStep.additionalKubernetesResource))
- assert(preparedDriverSpec.otherKubernetesResources.exists {
- case configMap: ConfigMap =>
- val hasMatchingName = configMap.getMetadata.getName == CONFIG_MAP_NAME
- val configMapData = configMap.getData.asScala
- val hasCorrectNumberOfEntries = configMapData.size == 1
- val initContainerPropertiesRaw = configMapData(CONFIG_MAP_KEY)
- val initContainerProperties = new Properties()
- Utils.tryWithResource(new StringReader(initContainerPropertiesRaw)) {
- initContainerProperties.load(_)
- }
- val initContainerPropertiesMap = Maps.fromProperties(initContainerProperties).asScala
- val expectedInitContainerProperties = Map(
- SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyKey ->
- SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyValue)
- val hasMatchingProperties = initContainerPropertiesMap == expectedInitContainerProperties
- hasMatchingName && hasCorrectNumberOfEntries && hasMatchingProperties
-
- case _ => false
- })
-
- val initContainers = preparedDriverSpec.driverPod.getSpec.getInitContainers
- assert(initContainers.size() === 1)
- val initContainerEnv = initContainers.get(0).getEnv.asScala
- assert(initContainerEnv.size === 1)
- assert(initContainerEnv.head.getName ===
- SecondTestInitContainerConfigurationStep.additionalInitContainerEnvKey)
- assert(initContainerEnv.head.getValue ===
- SecondTestInitContainerConfigurationStep.additionalInitContainerEnvValue)
-
- val expectedSparkConf = Map(
- INIT_CONTAINER_CONFIG_MAP_NAME.key -> CONFIG_MAP_NAME,
- INIT_CONTAINER_CONFIG_MAP_KEY_CONF.key -> CONFIG_MAP_KEY,
- SecondTestInitContainerConfigurationStep.additionalDriverSparkConfKey ->
- SecondTestInitContainerConfigurationStep.additionalDriverSparkConfValue)
- assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf)
- }
-}
-
-private object FirstTestInitContainerConfigurationStep extends InitContainerConfigurationStep {
-
- val additionalLabels = Map("additionalLabelkey" -> "additionalLabelValue")
- val additionalMainContainerEnvKey = "TEST_ENV_MAIN_KEY"
- val additionalMainContainerEnvValue = "TEST_ENV_MAIN_VALUE"
- val additionalKubernetesResource = new SecretBuilder()
- .withNewMetadata()
- .withName("test-secret")
- .endMetadata()
- .addToData("secret-key", "secret-value")
- .build()
-
- override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = {
- val driverPod = new PodBuilder(initContainerSpec.driverPod)
- .editOrNewMetadata()
- .addToLabels(additionalLabels.asJava)
- .endMetadata()
- .build()
- val mainContainer = new ContainerBuilder(initContainerSpec.driverContainer)
- .addNewEnv()
- .withName(additionalMainContainerEnvKey)
- .withValue(additionalMainContainerEnvValue)
- .endEnv()
- .build()
- initContainerSpec.copy(
- driverPod = driverPod,
- driverContainer = mainContainer,
- dependentResources = initContainerSpec.dependentResources ++
- Seq(additionalKubernetesResource))
- }
-}
-
-private object SecondTestInitContainerConfigurationStep extends InitContainerConfigurationStep {
- val additionalInitContainerEnvKey = "TEST_ENV_INIT_KEY"
- val additionalInitContainerEnvValue = "TEST_ENV_INIT_VALUE"
- val additionalInitContainerPropertyKey = "spark.initcontainer.testkey"
- val additionalInitContainerPropertyValue = "testvalue"
- val additionalDriverSparkConfKey = "spark.driver.testkey"
- val additionalDriverSparkConfValue = "spark.driver.testvalue"
-
- override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = {
- val initContainer = new ContainerBuilder(initContainerSpec.initContainer)
- .addNewEnv()
- .withName(additionalInitContainerEnvKey)
- .withValue(additionalInitContainerEnvValue)
- .endEnv()
- .build()
- val initContainerProperties = initContainerSpec.properties ++
- Map(additionalInitContainerPropertyKey -> additionalInitContainerPropertyValue)
- val driverSparkConf = initContainerSpec.driverSparkConf ++
- Map(additionalDriverSparkConfKey -> additionalDriverSparkConfValue)
- initContainerSpec.copy(
- initContainer = initContainer,
- properties = initContainerProperties,
- driverSparkConf = driverSparkConf)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala
deleted file mode 100644
index 78c8c3ba1afbd..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala
+++ /dev/null
@@ -1,180 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps
-
-import scala.collection.JavaConverters._
-
-import io.fabric8.kubernetes.api.model.Service
-import org.mockito.{Mock, MockitoAnnotations}
-import org.mockito.Mockito.when
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
-import org.apache.spark.util.Clock
-
-class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter {
-
- private val SHORT_RESOURCE_NAME_PREFIX =
- "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH -
- DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length)
-
- private val LONG_RESOURCE_NAME_PREFIX =
- "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH -
- DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length + 1)
- private val DRIVER_LABELS = Map(
- "label1key" -> "label1value",
- "label2key" -> "label2value")
-
- @Mock
- private var clock: Clock = _
-
- private var sparkConf: SparkConf = _
-
- before {
- MockitoAnnotations.initMocks(this)
- sparkConf = new SparkConf(false)
- }
-
- test("Headless service has a port for the driver RPC and the block manager.") {
- val configurationStep = new DriverServiceBootstrapStep(
- SHORT_RESOURCE_NAME_PREFIX,
- DRIVER_LABELS,
- sparkConf
- .set("spark.driver.port", "9000")
- .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080),
- clock)
- val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone())
- val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec)
- assert(resolvedDriverSpec.otherKubernetesResources.size === 1)
- assert(resolvedDriverSpec.otherKubernetesResources.head.isInstanceOf[Service])
- val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service]
- verifyService(
- 9000,
- 8080,
- s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}",
- driverService)
- }
-
- test("Hostname and ports are set according to the service name.") {
- val configurationStep = new DriverServiceBootstrapStep(
- SHORT_RESOURCE_NAME_PREFIX,
- DRIVER_LABELS,
- sparkConf
- .set("spark.driver.port", "9000")
- .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080)
- .set(KUBERNETES_NAMESPACE, "my-namespace"),
- clock)
- val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone())
- val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec)
- val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX +
- DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX
- val expectedHostName = s"$expectedServiceName.my-namespace.svc"
- verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName)
- }
-
- test("Ports should resolve to defaults in SparkConf and in the service.") {
- val configurationStep = new DriverServiceBootstrapStep(
- SHORT_RESOURCE_NAME_PREFIX,
- DRIVER_LABELS,
- sparkConf,
- clock)
- val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone())
- val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec)
- verifyService(
- DEFAULT_DRIVER_PORT,
- DEFAULT_BLOCKMANAGER_PORT,
- s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}",
- resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service])
- assert(resolvedDriverSpec.driverSparkConf.get("spark.driver.port") ===
- DEFAULT_DRIVER_PORT.toString)
- assert(resolvedDriverSpec.driverSparkConf.get(
- org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT) === DEFAULT_BLOCKMANAGER_PORT)
- }
-
- test("Long prefixes should switch to using a generated name.") {
- val configurationStep = new DriverServiceBootstrapStep(
- LONG_RESOURCE_NAME_PREFIX,
- DRIVER_LABELS,
- sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"),
- clock)
- when(clock.getTimeMillis()).thenReturn(10000)
- val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone())
- val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec)
- val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service]
- val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}"
- assert(driverService.getMetadata.getName === expectedServiceName)
- val expectedHostName = s"$expectedServiceName.my-namespace.svc"
- verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName)
- }
-
- test("Disallow bind address and driver host to be set explicitly.") {
- val configurationStep = new DriverServiceBootstrapStep(
- LONG_RESOURCE_NAME_PREFIX,
- DRIVER_LABELS,
- sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"),
- clock)
- try {
- configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf))
- fail("The driver bind address should not be allowed.")
- } catch {
- case e: Throwable =>
- assert(e.getMessage ===
- s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_BIND_ADDRESS_KEY} is" +
- " not supported in Kubernetes mode, as the driver's bind address is managed" +
- " and set to the driver pod's IP address.")
- }
- sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS)
- sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host")
- try {
- configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf))
- fail("The driver host address should not be allowed.")
- } catch {
- case e: Throwable =>
- assert(e.getMessage ===
- s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_HOST_KEY} is" +
- " not supported in Kubernetes mode, as the driver's hostname will be managed via" +
- " a Kubernetes service.")
- }
- }
-
- private def verifyService(
- driverPort: Int,
- blockManagerPort: Int,
- expectedServiceName: String,
- service: Service): Unit = {
- assert(service.getMetadata.getName === expectedServiceName)
- assert(service.getSpec.getClusterIP === "None")
- assert(service.getSpec.getSelector.asScala === DRIVER_LABELS)
- assert(service.getSpec.getPorts.size() === 2)
- val driverServicePorts = service.getSpec.getPorts.asScala
- assert(driverServicePorts.head.getName === DRIVER_PORT_NAME)
- assert(driverServicePorts.head.getPort.intValue() === driverPort)
- assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort)
- assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME)
- assert(driverServicePorts(1).getPort.intValue() === blockManagerPort)
- assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort)
- }
-
- private def verifySparkConfHostNames(
- driverSparkConf: SparkConf, expectedHostName: String): Unit = {
- assert(driverSparkConf.get(
- org.apache.spark.internal.config.DRIVER_HOST_ADDRESS) === expectedHostName)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala
deleted file mode 100644
index 4553f9f6b1d45..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps.initcontainer
-
-import scala.collection.JavaConverters._
-
-import io.fabric8.kubernetes.api.model._
-import org.mockito.{Mock, MockitoAnnotations}
-import org.mockito.Matchers.any
-import org.mockito.Mockito.when
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer}
-import org.apache.spark.deploy.k8s.Config._
-
-class BasicInitContainerConfigurationStepSuite extends SparkFunSuite with BeforeAndAfter {
-
- private val SPARK_JARS = Seq(
- "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar")
- private val SPARK_FILES = Seq(
- "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt")
- private val JARS_DOWNLOAD_PATH = "/var/data/jars"
- private val FILES_DOWNLOAD_PATH = "/var/data/files"
- private val POD_LABEL = Map("bootstrap" -> "true")
- private val INIT_CONTAINER_NAME = "init-container"
- private val DRIVER_CONTAINER_NAME = "driver-container"
-
- @Mock
- private var podAndInitContainerBootstrap : InitContainerBootstrap = _
-
- before {
- MockitoAnnotations.initMocks(this)
- when(podAndInitContainerBootstrap.bootstrapInitContainer(
- any[PodWithDetachedInitContainer])).thenAnswer(new Answer[PodWithDetachedInitContainer] {
- override def answer(invocation: InvocationOnMock) : PodWithDetachedInitContainer = {
- val pod = invocation.getArgumentAt(0, classOf[PodWithDetachedInitContainer])
- pod.copy(
- pod = new PodBuilder(pod.pod)
- .withNewMetadata()
- .addToLabels("bootstrap", "true")
- .endMetadata()
- .withNewSpec().endSpec()
- .build(),
- initContainer = new ContainerBuilder()
- .withName(INIT_CONTAINER_NAME)
- .build(),
- mainContainer = new ContainerBuilder()
- .withName(DRIVER_CONTAINER_NAME)
- .build()
- )}})
- }
-
- test("additionalDriverSparkConf with mix of remote files and jars") {
- val baseInitStep = new BasicInitContainerConfigurationStep(
- SPARK_JARS,
- SPARK_FILES,
- JARS_DOWNLOAD_PATH,
- FILES_DOWNLOAD_PATH,
- podAndInitContainerBootstrap)
- val expectedDriverSparkConf = Map(
- JARS_DOWNLOAD_LOCATION.key -> JARS_DOWNLOAD_PATH,
- FILES_DOWNLOAD_LOCATION.key -> FILES_DOWNLOAD_PATH,
- INIT_CONTAINER_REMOTE_JARS.key -> "hdfs://localhost:9000/app/jars/jar1.jar",
- INIT_CONTAINER_REMOTE_FILES.key -> "hdfs://localhost:9000/app/files/file1.txt")
- val initContainerSpec = InitContainerSpec(
- Map.empty[String, String],
- Map.empty[String, String],
- new Container(),
- new Container(),
- new Pod,
- Seq.empty[HasMetadata])
- val returnContainerSpec = baseInitStep.configureInitContainer(initContainerSpec)
- assert(expectedDriverSparkConf === returnContainerSpec.properties)
- assert(returnContainerSpec.initContainer.getName === INIT_CONTAINER_NAME)
- assert(returnContainerSpec.driverContainer.getName === DRIVER_CONTAINER_NAME)
- assert(returnContainerSpec.driverPod.getMetadata.getLabels.asScala === POD_LABEL)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala
deleted file mode 100644
index 09b42e4484d86..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps.initcontainer
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-
-class InitContainerConfigOrchestratorSuite extends SparkFunSuite {
-
- private val DOCKER_IMAGE = "init-container"
- private val SPARK_JARS = Seq(
- "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar")
- private val SPARK_FILES = Seq(
- "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt")
- private val JARS_DOWNLOAD_PATH = "/var/data/jars"
- private val FILES_DOWNLOAD_PATH = "/var/data/files"
- private val DOCKER_IMAGE_PULL_POLICY: String = "IfNotPresent"
- private val CUSTOM_LABEL_KEY = "customLabel"
- private val CUSTOM_LABEL_VALUE = "customLabelValue"
- private val INIT_CONTAINER_CONFIG_MAP_NAME = "spark-init-config-map"
- private val INIT_CONTAINER_CONFIG_MAP_KEY = "spark-init-config-map-key"
- private val SECRET_FOO = "foo"
- private val SECRET_BAR = "bar"
- private val SECRET_MOUNT_PATH = "/etc/secrets/init-container"
-
- test("including basic configuration step") {
- val sparkConf = new SparkConf(true)
- .set(CONTAINER_IMAGE, DOCKER_IMAGE)
- .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE)
-
- val orchestrator = new InitContainerConfigOrchestrator(
- SPARK_JARS.take(1),
- SPARK_FILES,
- JARS_DOWNLOAD_PATH,
- FILES_DOWNLOAD_PATH,
- DOCKER_IMAGE_PULL_POLICY,
- INIT_CONTAINER_CONFIG_MAP_NAME,
- INIT_CONTAINER_CONFIG_MAP_KEY,
- sparkConf)
- val initSteps = orchestrator.getAllConfigurationSteps
- assert(initSteps.lengthCompare(1) == 0)
- assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep])
- }
-
- test("including step to mount user-specified secrets") {
- val sparkConf = new SparkConf(false)
- .set(CONTAINER_IMAGE, DOCKER_IMAGE)
- .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH)
- .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH)
-
- val orchestrator = new InitContainerConfigOrchestrator(
- SPARK_JARS.take(1),
- SPARK_FILES,
- JARS_DOWNLOAD_PATH,
- FILES_DOWNLOAD_PATH,
- DOCKER_IMAGE_PULL_POLICY,
- INIT_CONTAINER_CONFIG_MAP_NAME,
- INIT_CONTAINER_CONFIG_MAP_KEY,
- sparkConf)
- val initSteps = orchestrator.getAllConfigurationSteps
- assert(initSteps.length === 2)
- assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep])
- assert(initSteps(1).isInstanceOf[InitContainerMountSecretsStep])
- }
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala
deleted file mode 100644
index 7ac0bde80dfe6..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps.initcontainer
-
-import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder}
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils}
-
-class InitContainerMountSecretsStepSuite extends SparkFunSuite {
-
- private val SECRET_FOO = "foo"
- private val SECRET_BAR = "bar"
- private val SECRET_MOUNT_PATH = "/etc/secrets/init-container"
-
- test("mounts all given secrets") {
- val baseInitContainerSpec = InitContainerSpec(
- Map.empty,
- Map.empty,
- new ContainerBuilder().build(),
- new ContainerBuilder().build(),
- new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(),
- Seq.empty)
- val secretNamesToMountPaths = Map(
- SECRET_FOO -> SECRET_MOUNT_PATH,
- SECRET_BAR -> SECRET_MOUNT_PATH)
-
- val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths)
- val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap)
- val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer(
- baseInitContainerSpec)
- val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer
-
- Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName =>
- assert(SecretVolumeUtils.containerHasVolume(
- initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH)))
- }
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala
new file mode 100644
index 0000000000000..f7721e6fd6388
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import io.fabric8.kubernetes.api.model.Pod
+import scala.collection.mutable
+
+class DeterministicExecutorPodsSnapshotsStore extends ExecutorPodsSnapshotsStore {
+
+ private val snapshotsBuffer = mutable.Buffer.empty[ExecutorPodsSnapshot]
+ private val subscribers = mutable.Buffer.empty[Seq[ExecutorPodsSnapshot] => Unit]
+
+ private var currentSnapshot = ExecutorPodsSnapshot()
+
+ override def addSubscriber
+ (processBatchIntervalMillis: Long)
+ (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = {
+ subscribers += onNewSnapshots
+ }
+
+ override def stop(): Unit = {}
+
+ def notifySubscribers(): Unit = {
+ subscribers.foreach(_(snapshotsBuffer))
+ snapshotsBuffer.clear()
+ }
+
+ override def updatePod(updatedPod: Pod): Unit = {
+ currentSnapshot = currentSnapshot.withUpdate(updatedPod)
+ snapshotsBuffer += currentSnapshot
+ }
+
+ override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = {
+ currentSnapshot = ExecutorPodsSnapshot(newSnapshot)
+ snapshotsBuffer += currentSnapshot
+ }
+}
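
The store above is a test double for ExecutorPodsSnapshotsStore that takes timing out of the picture: instead of delivering snapshots to subscribers on a scheduled interval (the interval argument is simply ignored), it buffers every snapshot and hands the buffer over only when a test calls notifySubscribers(), so suites such as ExecutorPodsAllocatorSuite below can step the allocator deterministically. A minimal usage sketch, reusing the pod fixtures from ExecutorLifecycleTestUtils added just below:

import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._

val store = new DeterministicExecutorPodsSnapshotsStore()
var delivered = Seq.empty[ExecutorPodsSnapshot]
store.addSubscriber(0L) { snapshots => delivered ++= snapshots }

store.replaceSnapshot(Seq.empty)        // snapshot 1: no executors known yet
store.updatePod(runningExecutor(1L))    // snapshot 2: executor 1 reported running
store.notifySubscribers()               // both buffered snapshots delivered synchronously
assert(delivered.size == 2)
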
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala
new file mode 100644
index 0000000000000..c6b667ed85e8c
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import io.fabric8.kubernetes.api.model.{ContainerBuilder, Pod, PodBuilder}
+
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.SparkPod
+
+object ExecutorLifecycleTestUtils {
+
+ val TEST_SPARK_APP_ID = "spark-app-id"
+
+ def failedExecutorWithoutDeletion(executorId: Long): Pod = {
+ new PodBuilder(podWithAttachedContainerForId(executorId))
+ .editOrNewStatus()
+ .withPhase("failed")
+ .addNewContainerStatus()
+ .withName("spark-executor")
+ .withImage("k8s-spark")
+ .withNewState()
+ .withNewTerminated()
+ .withMessage("Failed")
+ .withExitCode(1)
+ .endTerminated()
+ .endState()
+ .endContainerStatus()
+ .addNewContainerStatus()
+ .withName("spark-executor-sidecar")
+ .withImage("k8s-spark-sidecar")
+ .withNewState()
+ .withNewTerminated()
+ .withMessage("Failed")
+ .withExitCode(1)
+ .endTerminated()
+ .endState()
+ .endContainerStatus()
+ .withMessage("Executor failed.")
+ .withReason("Executor failed because of a thrown error.")
+ .endStatus()
+ .build()
+ }
+
+ def pendingExecutor(executorId: Long): Pod = {
+ new PodBuilder(podWithAttachedContainerForId(executorId))
+ .editOrNewStatus()
+ .withPhase("pending")
+ .endStatus()
+ .build()
+ }
+
+ def runningExecutor(executorId: Long): Pod = {
+ new PodBuilder(podWithAttachedContainerForId(executorId))
+ .editOrNewStatus()
+ .withPhase("running")
+ .endStatus()
+ .build()
+ }
+
+ def succeededExecutor(executorId: Long): Pod = {
+ new PodBuilder(podWithAttachedContainerForId(executorId))
+ .editOrNewStatus()
+ .withPhase("succeeded")
+ .endStatus()
+ .build()
+ }
+
+ def deletedExecutor(executorId: Long): Pod = {
+ new PodBuilder(podWithAttachedContainerForId(executorId))
+ .editOrNewMetadata()
+ .withNewDeletionTimestamp("523012521")
+ .endMetadata()
+ .build()
+ }
+
+ def unknownExecutor(executorId: Long): Pod = {
+ new PodBuilder(podWithAttachedContainerForId(executorId))
+ .editOrNewStatus()
+ .withPhase("unknown")
+ .endStatus()
+ .build()
+ }
+
+ def podWithAttachedContainerForId(executorId: Long): Pod = {
+ val sparkPod = executorPodWithId(executorId)
+ val podWithAttachedContainer = new PodBuilder(sparkPod.pod)
+ .editOrNewSpec()
+ .addToContainers(sparkPod.container)
+ .endSpec()
+ .build()
+ podWithAttachedContainer
+ }
+
+ def executorPodWithId(executorId: Long): SparkPod = {
+ val pod = new PodBuilder()
+ .withNewMetadata()
+ .withName(s"spark-executor-$executorId")
+ .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)
+ .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)
+ .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString)
+ .endMetadata()
+ .build()
+ val container = new ContainerBuilder()
+ .withName("spark-executor")
+ .withImage("k8s-spark")
+ .build()
+ SparkPod(pod, container)
+ }
+}
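
ExecutorLifecycleTestUtils gives the scheduler-backend suites a single source of canonical executor pods: every helper decorates the same base pod produced by executorPodWithId, changing only the status phase (or the deletion timestamp), so tests can both feed these pods through snapshot stores and compare against them structurally, as ExecutorPodsAllocatorSuite does below with verify(podOperations).create(podWithAttachedContainerForId(...)). A small illustrative check of the fixtures themselves:

import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._

val pending = pendingExecutor(7L)
val running = runningExecutor(7L)

// Both fixtures share the metadata created by executorPodWithId...
assert(pending.getMetadata.getName == "spark-executor-7")
assert(pending.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) == "7")

// ...and differ only in the status phase each lifecycle helper sets.
assert(pending.getStatus.getPhase == "pending")
assert(running.getStatus.getPhase == "running")
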
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala
deleted file mode 100644
index a3c615be031d2..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala
+++ /dev/null
@@ -1,210 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.scheduler.cluster.k8s
-
-import scala.collection.JavaConverters._
-
-import io.fabric8.kubernetes.api.model._
-import org.mockito.{AdditionalAnswers, MockitoAnnotations}
-import org.mockito.Matchers.any
-import org.mockito.Mockito._
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach}
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils}
-import org.apache.spark.deploy.k8s.Config._
-import org.apache.spark.deploy.k8s.Constants._
-
-class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach {
-
- private val driverPodName: String = "driver-pod"
- private val driverPodUid: String = "driver-uid"
- private val executorPrefix: String = "base"
- private val executorImage: String = "executor-image"
- private val driverPod = new PodBuilder()
- .withNewMetadata()
- .withName(driverPodName)
- .withUid(driverPodUid)
- .endMetadata()
- .withNewSpec()
- .withNodeName("some-node")
- .endSpec()
- .withNewStatus()
- .withHostIP("192.168.99.100")
- .endStatus()
- .build()
- private var baseConf: SparkConf = _
-
- before {
- MockitoAnnotations.initMocks(this)
- baseConf = new SparkConf()
- .set(KUBERNETES_DRIVER_POD_NAME, driverPodName)
- .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix)
- .set(CONTAINER_IMAGE, executorImage)
- }
-
- test("basic executor pod has reasonable defaults") {
- val factory = new ExecutorPodFactory(baseConf, None, None, None)
- val executor = factory.createExecutorPod(
- "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]())
-
- // The executor pod name and default labels.
- assert(executor.getMetadata.getName === s"$executorPrefix-exec-1")
- assert(executor.getMetadata.getLabels.size() === 3)
- assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1")
-
- // There is exactly 1 container with no volume mounts and default memory limits.
- // Default memory limit is 1024M + 384M (minimum overhead constant).
- assert(executor.getSpec.getContainers.size() === 1)
- assert(executor.getSpec.getContainers.get(0).getImage === executorImage)
- assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty)
- assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1)
- assert(executor.getSpec.getContainers.get(0).getResources
- .getLimits.get("memory").getAmount === "1408Mi")
-
- // The pod has no node selector, volumes.
- assert(executor.getSpec.getNodeSelector.isEmpty)
- assert(executor.getSpec.getVolumes.isEmpty)
-
- checkEnv(executor, Map())
- checkOwnerReferences(executor, driverPodUid)
- }
-
- test("executor pod hostnames get truncated to 63 characters") {
- val conf = baseConf.clone()
- conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX,
- "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple")
-
- val factory = new ExecutorPodFactory(conf, None, None, None)
- val executor = factory.createExecutorPod(
- "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]())
-
- assert(executor.getSpec.getHostname.length === 63)
- }
-
- test("classpath and extra java options get translated into environment variables") {
- val conf = baseConf.clone()
- conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar")
- conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz")
-
- val factory = new ExecutorPodFactory(conf, None, None, None)
- val executor = factory.createExecutorPod(
- "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]())
-
- checkEnv(executor,
- Map("SPARK_JAVA_OPT_0" -> "foo=bar",
- ENV_CLASSPATH -> "bar=baz",
- "qux" -> "quux"))
- checkOwnerReferences(executor, driverPodUid)
- }
-
- test("executor secrets get mounted") {
- val conf = baseConf.clone()
-
- val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1"))
- val factory = new ExecutorPodFactory(
- conf,
- Some(secretsBootstrap),
- None,
- None)
- val executor = factory.createExecutorPod(
- "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]())
-
- assert(executor.getSpec.getContainers.size() === 1)
- assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() === 1)
- assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0).getName
- === "secret1-volume")
- assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0)
- .getMountPath === "/var/secret1")
-
- // check volume mounted.
- assert(executor.getSpec.getVolumes.size() === 1)
- assert(executor.getSpec.getVolumes.get(0).getSecret.getSecretName === "secret1")
-
- checkOwnerReferences(executor, driverPodUid)
- }
-
- test("init-container bootstrap step adds an init container") {
- val conf = baseConf.clone()
- val initContainerBootstrap = mock(classOf[InitContainerBootstrap])
- when(initContainerBootstrap.bootstrapInitContainer(
- any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
-
- val factory = new ExecutorPodFactory(
- conf,
- None,
- Some(initContainerBootstrap),
- None)
- val executor = factory.createExecutorPod(
- "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]())
-
- assert(executor.getSpec.getInitContainers.size() === 1)
- checkOwnerReferences(executor, driverPodUid)
- }
-
- test("init-container with secrets mount bootstrap") {
- val conf = baseConf.clone()
- val initContainerBootstrap = mock(classOf[InitContainerBootstrap])
- when(initContainerBootstrap.bootstrapInitContainer(
- any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
- val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1"))
-
- val factory = new ExecutorPodFactory(
- conf,
- Some(secretsBootstrap),
- Some(initContainerBootstrap),
- Some(secretsBootstrap))
- val executor = factory.createExecutorPod(
- "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]())
-
- assert(executor.getSpec.getVolumes.size() === 1)
- assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume"))
- assert(SecretVolumeUtils.containerHasVolume(
- executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1"))
- assert(executor.getSpec.getInitContainers.size() === 1)
- assert(SecretVolumeUtils.containerHasVolume(
- executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1"))
-
- checkOwnerReferences(executor, driverPodUid)
- }
-
- // There is always exactly one controller reference, and it points to the driver pod.
- private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = {
- assert(executor.getMetadata.getOwnerReferences.size() === 1)
- assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid)
- assert(executor.getMetadata.getOwnerReferences.get(0).getController === true)
- }
-
- // Check that the expected environment variables are present.
- private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = {
- val defaultEnvs = Map(
- ENV_EXECUTOR_ID -> "1",
- ENV_DRIVER_URL -> "dummy",
- ENV_EXECUTOR_CORES -> "1",
- ENV_EXECUTOR_MEMORY -> "1g",
- ENV_APPLICATION_ID -> "dummy",
- ENV_EXECUTOR_POD_IP -> null,
- ENV_MOUNTED_CLASSPATH -> "/var/spark-data/spark-jars/*") ++ additionalEnvVars
-
- assert(executor.getSpec.getContainers.size() === 1)
- assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size)
- val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map {
- x => (x.getName, x.getValue)
- }.toMap
- assert(defaultEnvs === mapEnvs)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala
new file mode 100644
index 0000000000000..0c19f5946b75f
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder}
+import io.fabric8.kubernetes.client.KubernetesClient
+import io.fabric8.kubernetes.client.dsl.PodResource
+import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations}
+import org.mockito.Matchers.any
+import org.mockito.Mockito.{never, times, verify, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.Fabric8Aliases._
+import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._
+import org.apache.spark.util.ManualClock
+
+class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter {
+
+ private val driverPodName = "driver"
+
+ private val driverPod = new PodBuilder()
+ .withNewMetadata()
+ .withName(driverPodName)
+ .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)
+ .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE)
+ .withUid("driver-pod-uid")
+ .endMetadata()
+ .build()
+
+ private val conf = new SparkConf().set(KUBERNETES_DRIVER_POD_NAME, driverPodName)
+
+ private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE)
+ private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY)
+ private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000L)
+
+ private var waitForExecutorPodsClock: ManualClock = _
+
+ @Mock
+ private var kubernetesClient: KubernetesClient = _
+
+ @Mock
+ private var podOperations: PODS = _
+
+ @Mock
+ private var labeledPods: LABELED_PODS = _
+
+ @Mock
+ private var driverPodOperations: PodResource[Pod, DoneablePod] = _
+
+ @Mock
+ private var executorBuilder: KubernetesExecutorBuilder = _
+
+ private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _
+
+ private var podsAllocatorUnderTest: ExecutorPodsAllocator = _
+
+ before {
+ MockitoAnnotations.initMocks(this)
+ when(kubernetesClient.pods()).thenReturn(podOperations)
+ when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations)
+ when(driverPodOperations.get).thenReturn(driverPod)
+ when(executorBuilder.buildFromFeatures(kubernetesConfWithCorrectFields()))
+ .thenAnswer(executorPodAnswer())
+ snapshotsStore = new DeterministicExecutorPodsSnapshotsStore()
+ waitForExecutorPodsClock = new ManualClock(0L)
+ podsAllocatorUnderTest = new ExecutorPodsAllocator(
+ conf, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock)
+ podsAllocatorUnderTest.start(TEST_SPARK_APP_ID)
+ }
+
+ test("Initially request executors in batches. Do not request another batch if the" +
+ " first has not finished.") {
+ podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1)
+ snapshotsStore.replaceSnapshot(Seq.empty[Pod])
+ snapshotsStore.notifySubscribers()
+ for (nextId <- 1 to podAllocationSize) {
+ verify(podOperations).create(podWithAttachedContainerForId(nextId))
+ }
+ verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1))
+ }
+
+ test("Request executors in batches. Allow another batch to be requested if" +
+ " all pending executors start running.") {
+ podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1)
+ snapshotsStore.replaceSnapshot(Seq.empty[Pod])
+ snapshotsStore.notifySubscribers()
+ for (execId <- 1 until podAllocationSize) {
+ snapshotsStore.updatePod(runningExecutor(execId))
+ }
+ snapshotsStore.notifySubscribers()
+ verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1))
+ snapshotsStore.updatePod(runningExecutor(podAllocationSize))
+ snapshotsStore.notifySubscribers()
+ verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1))
+ snapshotsStore.updatePod(runningExecutor(podAllocationSize))
+ snapshotsStore.notifySubscribers()
+ verify(podOperations, times(podAllocationSize + 1)).create(any(classOf[Pod]))
+ }
+
+ test("When a current batch reaches error states immediately, re-request" +
+ " them on the next batch.") {
+ podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize)
+ snapshotsStore.replaceSnapshot(Seq.empty[Pod])
+ snapshotsStore.notifySubscribers()
+ for (execId <- 1 until podAllocationSize) {
+ snapshotsStore.updatePod(runningExecutor(execId))
+ }
+ val failedPod = failedExecutorWithoutDeletion(podAllocationSize)
+ snapshotsStore.updatePod(failedPod)
+ snapshotsStore.notifySubscribers()
+ verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1))
+ }
+
+ test("When an executor is requested but the API does not report it in a reasonable time, retry" +
+ " requesting that executor.") {
+ podsAllocatorUnderTest.setTotalExpectedExecutors(1)
+ snapshotsStore.replaceSnapshot(Seq.empty[Pod])
+ snapshotsStore.notifySubscribers()
+ snapshotsStore.replaceSnapshot(Seq.empty[Pod])
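+    // The API still reports no pod for the requested executor; move the clock past the
+    // creation timeout so the allocator deletes the stale request and asks for a replacement.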
+ waitForExecutorPodsClock.setTime(podCreationTimeout + 1)
+ when(podOperations.withLabel(SPARK_EXECUTOR_ID_LABEL, "1")).thenReturn(labeledPods)
+ snapshotsStore.notifySubscribers()
+ verify(labeledPods).delete()
+ verify(podOperations).create(podWithAttachedContainerForId(2))
+ }
+
+ private def executorPodAnswer(): Answer[SparkPod] = {
+ new Answer[SparkPod] {
+ override def answer(invocation: InvocationOnMock): SparkPod = {
+ val k8sConf = invocation.getArgumentAt(
+ 0, classOf[KubernetesConf[KubernetesExecutorSpecificConf]])
+ executorPodWithId(k8sConf.roleSpecificConf.executorId.toInt)
+ }
+ }
+ }
+
+ private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] =
+ Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] {
+ override def matches(argument: scala.Any): Boolean = {
+ if (!argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]) {
+ false
+ } else {
+ val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]
+ val executorSpecificConf = k8sConf.roleSpecificConf
+ val expectedK8sConf = KubernetesConf.createExecutorConf(
+ conf,
+ executorSpecificConf.executorId,
+ TEST_SPARK_APP_ID,
+ driverPod)
+ k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap &&
+ // Since KubernetesConf.createExecutorConf clones the SparkConf object, force
+ // deep equality comparison for the SparkConf object and use object equality
+ // comparison on all other fields.
+ k8sConf.copy(sparkConf = conf) == expectedK8sConf.copy(sparkConf = conf)
+ }
+ }
+ })
+
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala
new file mode 100644
index 0000000000000..562ace9f49d4d
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import com.google.common.cache.CacheBuilder
+import io.fabric8.kubernetes.api.model.{DoneablePod, Pod}
+import io.fabric8.kubernetes.client.KubernetesClient
+import io.fabric8.kubernetes.client.dsl.PodResource
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Matchers.any
+import org.mockito.Mockito.{mock, times, verify, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.BeforeAndAfter
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.Fabric8Aliases._
+import org.apache.spark.scheduler.ExecutorExited
+import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._
+
+class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfter {
+
+ private var namedExecutorPods: mutable.Map[String, PodResource[Pod, DoneablePod]] = _
+
+ @Mock
+ private var kubernetesClient: KubernetesClient = _
+
+ @Mock
+ private var podOperations: PODS = _
+
+ @Mock
+ private var executorBuilder: KubernetesExecutorBuilder = _
+
+ @Mock
+ private var schedulerBackend: KubernetesClusterSchedulerBackend = _
+
+ private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _
+ private var eventHandlerUnderTest: ExecutorPodsLifecycleManager = _
+
+ before {
+ MockitoAnnotations.initMocks(this)
+ val removedExecutorsCache = CacheBuilder.newBuilder().build[java.lang.Long, java.lang.Long]
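+    // Cache of executor ids already reported as removed, so Spark is not notified twice
+    // about the same executor.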
+ snapshotsStore = new DeterministicExecutorPodsSnapshotsStore()
+ namedExecutorPods = mutable.Map.empty[String, PodResource[Pod, DoneablePod]]
+ when(schedulerBackend.getExecutorIds()).thenReturn(Seq.empty[String])
+ when(kubernetesClient.pods()).thenReturn(podOperations)
+ when(podOperations.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer())
+ eventHandlerUnderTest = new ExecutorPodsLifecycleManager(
+ new SparkConf(),
+ executorBuilder,
+ kubernetesClient,
+ snapshotsStore,
+ removedExecutorsCache)
+ eventHandlerUnderTest.start(schedulerBackend)
+ }
+
+ test("When an executor reaches error states immediately, remove from the scheduler backend.") {
+ val failedPod = failedExecutorWithoutDeletion(1)
+ snapshotsStore.updatePod(failedPod)
+ snapshotsStore.notifySubscribers()
+ val msg = exitReasonMessage(1, failedPod)
+ val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg)
+ verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason)
+ verify(namedExecutorPods(failedPod.getMetadata.getName)).delete()
+ }
+
+ test("Don't remove executors twice from Spark but remove from K8s repeatedly.") {
+ val failedPod = failedExecutorWithoutDeletion(1)
+ snapshotsStore.updatePod(failedPod)
+ snapshotsStore.updatePod(failedPod)
+ snapshotsStore.notifySubscribers()
+ val msg = exitReasonMessage(1, failedPod)
+ val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg)
+ verify(schedulerBackend, times(1)).doRemoveExecutor("1", expectedLossReason)
+ verify(namedExecutorPods(failedPod.getMetadata.getName), times(2)).delete()
+ }
+
+ test("When the scheduler backend lists executor ids that aren't present in the cluster," +
+ " remove those executors from Spark.") {
+ when(schedulerBackend.getExecutorIds()).thenReturn(Seq("1"))
+ val msg = s"The executor with ID 1 was not found in the cluster but we didn't" +
+ s" get a reason why. Marking the executor as failed. The executor may have been" +
+ s" deleted but the driver missed the deletion event."
+ val expectedLossReason = ExecutorExited(-1, exitCausedByApp = false, msg)
+ snapshotsStore.replaceSnapshot(Seq.empty[Pod])
+ snapshotsStore.notifySubscribers()
+ verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason)
+ }
+
+ private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = {
+ s"""
+ |The executor with id $failedExecutorId exited with exit code 1.
+ |The API gave the following brief reason: ${failedPod.getStatus.getReason}
+ |The API gave the following message: ${failedPod.getStatus.getMessage}
+ |The API gave the following container statuses:
+ |
+ |${failedPod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")}
+ """.stripMargin
+ }
+
+ private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = {
+ new Answer[PodResource[Pod, DoneablePod]] {
+ override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = {
+ val podName = invocation.getArgumentAt(0, classOf[String])
+ namedExecutorPods.getOrElseUpdate(
+ podName, mock(classOf[PodResource[Pod, DoneablePod]]))
+ }
+ }
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala
new file mode 100644
index 0000000000000..1b26d6af296a5
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import java.util.concurrent.TimeUnit
+
+import io.fabric8.kubernetes.api.model.PodListBuilder
+import io.fabric8.kubernetes.client.KubernetesClient
+import org.jmock.lib.concurrent.DeterministicScheduler
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Mockito.{verify, when}
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.Fabric8Aliases._
+import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._
+
+class ExecutorPodsPollingSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter {
+
+ private val sparkConf = new SparkConf
+
+ private val pollingInterval = sparkConf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL)
+
+ @Mock
+ private var kubernetesClient: KubernetesClient = _
+
+ @Mock
+ private var podOperations: PODS = _
+
+ @Mock
+ private var appIdLabeledPods: LABELED_PODS = _
+
+ @Mock
+ private var executorRoleLabeledPods: LABELED_PODS = _
+
+ @Mock
+ private var eventQueue: ExecutorPodsSnapshotsStore = _
+
+ private var pollingExecutor: DeterministicScheduler = _
+ private var pollingSourceUnderTest: ExecutorPodsPollingSnapshotSource = _
+
+ before {
+ MockitoAnnotations.initMocks(this)
+ pollingExecutor = new DeterministicScheduler()
+ pollingSourceUnderTest = new ExecutorPodsPollingSnapshotSource(
+ sparkConf,
+ kubernetesClient,
+ eventQueue,
+ pollingExecutor)
+ pollingSourceUnderTest.start(TEST_SPARK_APP_ID)
+ when(kubernetesClient.pods()).thenReturn(podOperations)
+ when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID))
+ .thenReturn(appIdLabeledPods)
+ when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE))
+ .thenReturn(executorRoleLabeledPods)
+ }
+
+ test("Items returned by the API should be pushed to the event queue") {
+ when(executorRoleLabeledPods.list())
+ .thenReturn(new PodListBuilder()
+ .addToItems(
+ runningExecutor(1),
+ runningExecutor(2))
+ .build())
+ pollingExecutor.tick(pollingInterval, TimeUnit.MILLISECONDS)
+ verify(eventQueue).replaceSnapshot(Seq(runningExecutor(1), runningExecutor(2)))
+
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala
new file mode 100644
index 0000000000000..70e19c904eddb
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._
+
+class ExecutorPodsSnapshotSuite extends SparkFunSuite {
+
+ test("States are interpreted correctly from pod metadata.") {
+ val pods = Seq(
+ pendingExecutor(0),
+ runningExecutor(1),
+ succeededExecutor(2),
+ failedExecutorWithoutDeletion(3),
+ deletedExecutor(4),
+ unknownExecutor(5))
+ val snapshot = ExecutorPodsSnapshot(pods)
+ assert(snapshot.executorPods ===
+ Map(
+ 0L -> PodPending(pods(0)),
+ 1L -> PodRunning(pods(1)),
+ 2L -> PodSucceeded(pods(2)),
+ 3L -> PodFailed(pods(3)),
+ 4L -> PodDeleted(pods(4)),
+ 5L -> PodUnknown(pods(5))))
+ }
+
+ test("Updates add new pods for non-matching ids and edit existing pods for matching ids") {
+ val originalPods = Seq(
+ pendingExecutor(0),
+ runningExecutor(1))
+ val originalSnapshot = ExecutorPodsSnapshot(originalPods)
+ val snapshotWithUpdatedPod = originalSnapshot.withUpdate(succeededExecutor(1))
+ assert(snapshotWithUpdatedPod.executorPods ===
+ Map(
+ 0L -> PodPending(originalPods(0)),
+ 1L -> PodSucceeded(succeededExecutor(1))))
+ val snapshotWithNewPod = snapshotWithUpdatedPod.withUpdate(pendingExecutor(2))
+ assert(snapshotWithNewPod.executorPods ===
+ Map(
+ 0L -> PodPending(originalPods(0)),
+ 1L -> PodSucceeded(succeededExecutor(1)),
+ 2L -> PodPending(pendingExecutor(2))))
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala
new file mode 100644
index 0000000000000..cf54b3c4eb329
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
+
+import io.fabric8.kubernetes.api.model.{Pod, PodBuilder}
+import org.jmock.lib.concurrent.DeterministicScheduler
+import org.scalatest.BeforeAndAfter
+import scala.collection.mutable
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.deploy.k8s.Constants._
+
+class ExecutorPodsSnapshotsStoreSuite extends SparkFunSuite with BeforeAndAfter {
+
+ private var eventBufferScheduler: DeterministicScheduler = _
+ private var eventQueueUnderTest: ExecutorPodsSnapshotsStoreImpl = _
+
+ before {
+ eventBufferScheduler = new DeterministicScheduler()
+ eventQueueUnderTest = new ExecutorPodsSnapshotsStoreImpl(eventBufferScheduler)
+ }
+
+ test("Subscribers get notified of events periodically.") {
+ val receivedSnapshots1 = mutable.Buffer.empty[ExecutorPodsSnapshot]
+ val receivedSnapshots2 = mutable.Buffer.empty[ExecutorPodsSnapshot]
+ eventQueueUnderTest.addSubscriber(1000) {
+ receivedSnapshots1 ++= _
+ }
+ eventQueueUnderTest.addSubscriber(2000) {
+ receivedSnapshots2 ++= _
+ }
+
+ eventBufferScheduler.runUntilIdle()
+ assert(receivedSnapshots1 === Seq(ExecutorPodsSnapshot()))
+ assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot()))
+
+ pushPodWithIndex(1)
+ // Force time to move forward so that the buffer is emitted, scheduling the
+ // processing task on the subscription executor...
+ eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS)
+ // ... then actually execute the subscribers.
+
+ assert(receivedSnapshots1 === Seq(
+ ExecutorPodsSnapshot(),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1)))))
+ assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot()))
+
+ eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS)
+
+ // Don't repeat snapshots
+ assert(receivedSnapshots1 === Seq(
+ ExecutorPodsSnapshot(),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1)))))
+ assert(receivedSnapshots2 === Seq(
+ ExecutorPodsSnapshot(),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1)))))
+ pushPodWithIndex(2)
+ pushPodWithIndex(3)
+ eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS)
+
+ assert(receivedSnapshots1 === Seq(
+ ExecutorPodsSnapshot(),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1))),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3)))))
+ assert(receivedSnapshots2 === Seq(
+ ExecutorPodsSnapshot(),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1)))))
+
+ eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS)
+ assert(receivedSnapshots1 === Seq(
+ ExecutorPodsSnapshot(),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1))),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3)))))
+ assert(receivedSnapshots1 === receivedSnapshots2)
+ }
+
+ test("Even without sending events, initially receive an empty buffer.") {
+ val receivedInitialSnapshot = new AtomicReference[Seq[ExecutorPodsSnapshot]](null)
+ eventQueueUnderTest.addSubscriber(1000) {
+ receivedInitialSnapshot.set
+ }
+ assert(receivedInitialSnapshot.get == null)
+ eventBufferScheduler.runUntilIdle()
+ assert(receivedInitialSnapshot.get === Seq(ExecutorPodsSnapshot()))
+ }
+
+ test("Replacing the snapshot passes the new snapshot to subscribers.") {
+ val receivedSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot]
+ eventQueueUnderTest.addSubscriber(1000) {
+ receivedSnapshots ++= _
+ }
+ eventQueueUnderTest.updatePod(podWithIndex(1))
+ eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS)
+ assert(receivedSnapshots === Seq(
+ ExecutorPodsSnapshot(),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1)))))
+ eventQueueUnderTest.replaceSnapshot(Seq(podWithIndex(2)))
+ eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS)
+ assert(receivedSnapshots === Seq(
+ ExecutorPodsSnapshot(),
+ ExecutorPodsSnapshot(Seq(podWithIndex(1))),
+ ExecutorPodsSnapshot(Seq(podWithIndex(2)))))
+ }
+
+ private def pushPodWithIndex(index: Int): Unit =
+ eventQueueUnderTest.updatePod(podWithIndex(index))
+
+ private def podWithIndex(index: Int): Pod =
+ new PodBuilder()
+ .editOrNewMetadata()
+ .withName(s"pod-$index")
+ .addToLabels(SPARK_EXECUTOR_ID_LABEL, index.toString)
+ .endMetadata()
+ .editOrNewStatus()
+ .withPhase("running")
+ .endStatus()
+ .build()
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala
new file mode 100644
index 0000000000000..ac1968b4ff810
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import io.fabric8.kubernetes.api.model.Pod
+import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher}
+import io.fabric8.kubernetes.client.Watcher.Action
+import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations}
+import org.mockito.Mockito.{verify, when}
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.Fabric8Aliases._
+import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._
+
+class ExecutorPodsWatchSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter {
+
+ @Mock
+ private var eventQueue: ExecutorPodsSnapshotsStore = _
+
+ @Mock
+ private var kubernetesClient: KubernetesClient = _
+
+ @Mock
+ private var podOperations: PODS = _
+
+ @Mock
+ private var appIdLabeledPods: LABELED_PODS = _
+
+ @Mock
+ private var executorRoleLabeledPods: LABELED_PODS = _
+
+ @Mock
+ private var watchConnection: Watch = _
+
+ private var watch: ArgumentCaptor[Watcher[Pod]] = _
+
+ private var watchSourceUnderTest: ExecutorPodsWatchSnapshotSource = _
+
+ before {
+ MockitoAnnotations.initMocks(this)
+ watch = ArgumentCaptor.forClass(classOf[Watcher[Pod]])
+ when(kubernetesClient.pods()).thenReturn(podOperations)
+ when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID))
+ .thenReturn(appIdLabeledPods)
+ when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE))
+ .thenReturn(executorRoleLabeledPods)
+ when(executorRoleLabeledPods.watch(watch.capture())).thenReturn(watchConnection)
+ watchSourceUnderTest = new ExecutorPodsWatchSnapshotSource(
+ eventQueue, kubernetesClient)
+ watchSourceUnderTest.start(TEST_SPARK_APP_ID)
+ }
+
+ test("Watch events should be pushed to the snapshots store as snapshot updates.") {
+ watch.getValue.eventReceived(Action.ADDED, runningExecutor(1))
+ watch.getValue.eventReceived(Action.MODIFIED, runningExecutor(2))
+ verify(eventQueue).updatePod(runningExecutor(1))
+ verify(eventQueue).updatePod(runningExecutor(2))
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
index b2f26f205a329..52e7a12dbaf06 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
@@ -16,85 +16,36 @@
*/
package org.apache.spark.scheduler.cluster.k8s
-import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit}
-
-import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList}
-import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher}
-import io.fabric8.kubernetes.client.Watcher.Action
-import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource}
-import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations}
-import org.mockito.Matchers.{any, eq => mockitoEq}
-import org.mockito.Mockito.{doNothing, never, times, verify, when}
+import io.fabric8.kubernetes.client.KubernetesClient
+import org.jmock.lib.concurrent.DeterministicScheduler
+import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations}
+import org.mockito.Matchers.{eq => mockitoEq}
+import org.mockito.Mockito.{never, verify, when}
import org.scalatest.BeforeAndAfter
-import org.scalatest.mockito.MockitoSugar._
-import scala.collection.JavaConverters._
-import scala.concurrent.Future
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.rpc._
-import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl}
-import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor}
+import org.apache.spark.deploy.k8s.Fabric8Aliases._
+import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.scheduler.{ExecutorKilled, TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils.TEST_SPARK_APP_ID
class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter {
- private val APP_ID = "test-spark-app"
- private val DRIVER_POD_NAME = "spark-driver-pod"
- private val NAMESPACE = "test-namespace"
- private val SPARK_DRIVER_HOST = "localhost"
- private val SPARK_DRIVER_PORT = 7077
- private val POD_ALLOCATION_INTERVAL = "1m"
- private val DRIVER_URL = RpcEndpointAddress(
- SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
- private val FIRST_EXECUTOR_POD = new PodBuilder()
- .withNewMetadata()
- .withName("pod1")
- .endMetadata()
- .withNewSpec()
- .withNodeName("node1")
- .endSpec()
- .withNewStatus()
- .withHostIP("192.168.99.100")
- .endStatus()
- .build()
- private val SECOND_EXECUTOR_POD = new PodBuilder()
- .withNewMetadata()
- .withName("pod2")
- .endMetadata()
- .withNewSpec()
- .withNodeName("node2")
- .endSpec()
- .withNewStatus()
- .withHostIP("192.168.99.101")
- .endStatus()
- .build()
-
- private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]
- private type LABELED_PODS = FilterWatchListDeletable[
- Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]]
- private type IN_NAMESPACE_PODS = NonNamespaceOperation[
- Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]
-
- @Mock
- private var sparkContext: SparkContext = _
-
- @Mock
- private var listenerBus: LiveListenerBus = _
-
- @Mock
- private var taskSchedulerImpl: TaskSchedulerImpl = _
+ private val requestExecutorsService = new DeterministicScheduler()
+ private val sparkConf = new SparkConf(false)
+ .set("spark.executor.instances", "3")
@Mock
- private var allocatorExecutor: ScheduledExecutorService = _
+ private var sc: SparkContext = _
@Mock
- private var requestExecutorsService: ExecutorService = _
+ private var rpcEnv: RpcEnv = _
@Mock
- private var executorPodFactory: ExecutorPodFactory = _
+ private var driverEndpointRef: RpcEndpointRef = _
@Mock
private var kubernetesClient: KubernetesClient = _
@@ -103,338 +54,97 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn
private var podOperations: PODS = _
@Mock
- private var podsWithLabelOperations: LABELED_PODS = _
+ private var labeledPods: LABELED_PODS = _
@Mock
- private var podsInNamespace: IN_NAMESPACE_PODS = _
+ private var taskScheduler: TaskSchedulerImpl = _
@Mock
- private var podsWithDriverName: PodResource[Pod, DoneablePod] = _
+ private var eventQueue: ExecutorPodsSnapshotsStore = _
@Mock
- private var rpcEnv: RpcEnv = _
+ private var podAllocator: ExecutorPodsAllocator = _
@Mock
- private var driverEndpointRef: RpcEndpointRef = _
+ private var lifecycleEventHandler: ExecutorPodsLifecycleManager = _
@Mock
- private var executorPodsWatch: Watch = _
+ private var watchEvents: ExecutorPodsWatchSnapshotSource = _
@Mock
- private var successFuture: Future[Boolean] = _
+ private var pollEvents: ExecutorPodsPollingSnapshotSource = _
- private var sparkConf: SparkConf = _
- private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _
- private var allocatorRunnable: ArgumentCaptor[Runnable] = _
- private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _
private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _
-
- private val driverPod = new PodBuilder()
- .withNewMetadata()
- .withName(DRIVER_POD_NAME)
- .addToLabels(SPARK_APP_ID_LABEL, APP_ID)
- .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE)
- .endMetadata()
- .build()
+ private var schedulerBackendUnderTest: KubernetesClusterSchedulerBackend = _
before {
MockitoAnnotations.initMocks(this)
- sparkConf = new SparkConf()
- .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME)
- .set(KUBERNETES_NAMESPACE, NAMESPACE)
- .set("spark.driver.host", SPARK_DRIVER_HOST)
- .set("spark.driver.port", SPARK_DRIVER_PORT.toString)
- .set(KUBERNETES_ALLOCATION_BATCH_DELAY.key, POD_ALLOCATION_INTERVAL)
- executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]])
- allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable])
- requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable])
+ when(taskScheduler.sc).thenReturn(sc)
+ when(sc.conf).thenReturn(sparkConf)
driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint])
- when(sparkContext.conf).thenReturn(sparkConf)
- when(sparkContext.listenerBus).thenReturn(listenerBus)
- when(taskSchedulerImpl.sc).thenReturn(sparkContext)
- when(kubernetesClient.pods()).thenReturn(podOperations)
- when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations)
- when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture()))
- .thenReturn(executorPodsWatch)
- when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace)
- when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName)
- when(podsWithDriverName.get()).thenReturn(driverPod)
- when(allocatorExecutor.scheduleWithFixedDelay(
- allocatorRunnable.capture(),
- mockitoEq(0L),
- mockitoEq(TimeUnit.MINUTES.toMillis(1)),
- mockitoEq(TimeUnit.MILLISECONDS))).thenReturn(null)
- // Creating Futures in Scala backed by a Java executor service resolves to running
- // ExecutorService#execute (as opposed to submit)
- doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture())
when(rpcEnv.setupEndpoint(
mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture()))
.thenReturn(driverEndpointRef)
-
- // Used by the CoarseGrainedSchedulerBackend when making RPC calls.
- when(driverEndpointRef.ask[Boolean]
- (any(classOf[Any]))
- (any())).thenReturn(successFuture)
- when(successFuture.failed).thenReturn(Future[Throwable] {
- // emulate behavior of the Future.failed method.
- throw new NoSuchElementException()
- }(ThreadUtils.sameThread))
- }
-
- test("Basic lifecycle expectations when starting and stopping the scheduler.") {
- val scheduler = newSchedulerBackend()
- scheduler.start()
- assert(executorPodsWatcherArgument.getValue != null)
- assert(allocatorRunnable.getValue != null)
- scheduler.stop()
- verify(executorPodsWatch).close()
- }
-
- test("Static allocation should request executors upon first allocator run.") {
- sparkConf
- .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2)
- .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2)
- val scheduler = newSchedulerBackend()
- scheduler.start()
- requestExecutorRunnable.getValue.run()
- val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
- val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
- when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
- allocatorRunnable.getValue.run()
- verify(podOperations).create(firstResolvedPod)
- verify(podOperations).create(secondResolvedPod)
- }
-
- test("Killing executors deletes the executor pods") {
- sparkConf
- .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2)
- .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2)
- val scheduler = newSchedulerBackend()
- scheduler.start()
- requestExecutorRunnable.getValue.run()
- val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
- val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
- when(podOperations.create(any(classOf[Pod])))
- .thenAnswer(AdditionalAnswers.returnsFirstArg())
- allocatorRunnable.getValue.run()
- scheduler.doKillExecutors(Seq("2"))
- requestExecutorRunnable.getAllValues.asScala.last.run()
- verify(podOperations).delete(secondResolvedPod)
- verify(podOperations, never()).delete(firstResolvedPod)
- }
-
- test("Executors should be requested in batches.") {
- sparkConf
- .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
- .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2)
- val scheduler = newSchedulerBackend()
- scheduler.start()
- requestExecutorRunnable.getValue.run()
- when(podOperations.create(any(classOf[Pod])))
- .thenAnswer(AdditionalAnswers.returnsFirstArg())
- val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
- val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
- allocatorRunnable.getValue.run()
- verify(podOperations).create(firstResolvedPod)
- verify(podOperations, never()).create(secondResolvedPod)
- val registerFirstExecutorMessage = RegisterExecutor(
- "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String])
- when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
- driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
- .apply(registerFirstExecutorMessage)
- allocatorRunnable.getValue.run()
- verify(podOperations).create(secondResolvedPod)
- }
-
- test("Scaled down executors should be cleaned up") {
- sparkConf
- .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
- .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
- val scheduler = newSchedulerBackend()
- scheduler.start()
-
- // The scheduler backend spins up one executor pod.
- requestExecutorRunnable.getValue.run()
- when(podOperations.create(any(classOf[Pod])))
- .thenAnswer(AdditionalAnswers.returnsFirstArg())
- val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
- allocatorRunnable.getValue.run()
- val executorEndpointRef = mock[RpcEndpointRef]
- when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000))
- val registerFirstExecutorMessage = RegisterExecutor(
- "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String])
- when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
- driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
- .apply(registerFirstExecutorMessage)
-
- // Request that there are 0 executors and trigger deletion from driver.
- scheduler.doRequestTotalExecutors(0)
- requestExecutorRunnable.getAllValues.asScala.last.run()
- scheduler.doKillExecutors(Seq("1"))
- requestExecutorRunnable.getAllValues.asScala.last.run()
- verify(podOperations, times(1)).delete(resolvedPod)
- driverEndpoint.getValue.onDisconnected(executorEndpointRef.address)
-
- val exitedPod = exitPod(resolvedPod, 0)
- executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod)
- allocatorRunnable.getValue.run()
-
- // No more deletion attempts of the executors.
- // This is graceful termination and should not be detected as a failure.
- verify(podOperations, times(1)).delete(resolvedPod)
- verify(driverEndpointRef, times(1)).send(
- RemoveExecutor("1", ExecutorExited(
- 0,
- exitCausedByApp = false,
- s"Container in pod ${exitedPod.getMetadata.getName} exited from" +
- s" explicit termination request.")))
- }
-
- test("Executors that fail should not be deleted.") {
- sparkConf
- .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
- .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
-
- val scheduler = newSchedulerBackend()
- scheduler.start()
- val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
- when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
- requestExecutorRunnable.getValue.run()
- allocatorRunnable.getValue.run()
- val executorEndpointRef = mock[RpcEndpointRef]
- when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000))
- val registerFirstExecutorMessage = RegisterExecutor(
- "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String])
- when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
- driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
- .apply(registerFirstExecutorMessage)
- driverEndpoint.getValue.onDisconnected(executorEndpointRef.address)
- executorPodsWatcherArgument.getValue.eventReceived(
- Action.ERROR, exitPod(firstResolvedPod, 1))
-
- // A replacement executor should be created but the error pod should persist.
- val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
- scheduler.doRequestTotalExecutors(1)
- requestExecutorRunnable.getValue.run()
- allocatorRunnable.getAllValues.asScala.last.run()
- verify(podOperations, never()).delete(firstResolvedPod)
- verify(driverEndpointRef).send(
- RemoveExecutor("1", ExecutorExited(
- 1,
- exitCausedByApp = true,
- s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" +
- " exit status code 1.")))
- }
-
- test("Executors disconnected due to unknown reasons are deleted and replaced.") {
- sparkConf
- .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
- .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
- val executorLostReasonCheckMaxAttempts = sparkConf.get(
- KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS)
-
- val scheduler = newSchedulerBackend()
- scheduler.start()
- val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
- when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
- requestExecutorRunnable.getValue.run()
- allocatorRunnable.getValue.run()
- val executorEndpointRef = mock[RpcEndpointRef]
- when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000))
- val registerFirstExecutorMessage = RegisterExecutor(
- "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String])
- when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
- driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
- .apply(registerFirstExecutorMessage)
-
- driverEndpoint.getValue.onDisconnected(executorEndpointRef.address)
- 1 to executorLostReasonCheckMaxAttempts foreach { _ =>
- allocatorRunnable.getValue.run()
- verify(podOperations, never()).delete(FIRST_EXECUTOR_POD)
+ when(kubernetesClient.pods()).thenReturn(podOperations)
+ schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend(
+ taskScheduler,
+ rpcEnv,
+ kubernetesClient,
+ requestExecutorsService,
+ eventQueue,
+ podAllocator,
+ lifecycleEventHandler,
+ watchEvents,
+ pollEvents) {
+ override def applicationId(): String = TEST_SPARK_APP_ID
}
-
- val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
- allocatorRunnable.getValue.run()
- verify(podOperations).delete(firstResolvedPod)
- verify(driverEndpointRef).send(
- RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons.")))
}
- test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") {
- sparkConf
- .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
- .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
- val scheduler = newSchedulerBackend()
- scheduler.start()
- val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
- when(podOperations.create(firstResolvedPod))
- .thenThrow(new RuntimeException("test"))
- requestExecutorRunnable.getValue.run()
- allocatorRunnable.getValue.run()
- verify(podOperations, times(1)).create(firstResolvedPod)
- val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD)
- allocatorRunnable.getValue.run()
- verify(podOperations).create(recreatedResolvedPod)
+ test("Start all components") {
+ schedulerBackendUnderTest.start()
+ verify(podAllocator).setTotalExpectedExecutors(3)
+ verify(podAllocator).start(TEST_SPARK_APP_ID)
+ verify(lifecycleEventHandler).start(schedulerBackendUnderTest)
+ verify(watchEvents).start(TEST_SPARK_APP_ID)
+ verify(pollEvents).start(TEST_SPARK_APP_ID)
}
- test("Executors that are initially created but the watch notices them fail are rebuilt" +
- " in the next batch.") {
- sparkConf
- .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
- .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
- val scheduler = newSchedulerBackend()
- scheduler.start()
- val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
- when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg())
- requestExecutorRunnable.getValue.run()
- allocatorRunnable.getValue.run()
- verify(podOperations, times(1)).create(firstResolvedPod)
- executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod)
- val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD)
- allocatorRunnable.getValue.run()
- verify(podOperations).create(recreatedResolvedPod)
+ test("Stop all components") {
+ when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods)
+ when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods)
+ schedulerBackendUnderTest.stop()
+ verify(eventQueue).stop()
+ verify(watchEvents).stop()
+ verify(pollEvents).stop()
+ verify(labeledPods).delete()
+ verify(kubernetesClient).close()
}
- private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = {
- new KubernetesClusterSchedulerBackend(
- taskSchedulerImpl,
- rpcEnv,
- executorPodFactory,
- kubernetesClient,
- allocatorExecutor,
- requestExecutorsService) {
-
- override def applicationId(): String = APP_ID
- }
+ test("Remove executor") {
+ schedulerBackendUnderTest.start()
+ schedulerBackendUnderTest.doRemoveExecutor(
+ "1", ExecutorKilled)
+ verify(driverEndpointRef).send(RemoveExecutor("1", ExecutorKilled))
}
- private def exitPod(basePod: Pod, exitCode: Int): Pod = {
- new PodBuilder(basePod)
- .editStatus()
- .addNewContainerStatus()
- .withNewState()
- .withNewTerminated()
- .withExitCode(exitCode)
- .endTerminated()
- .endState()
- .endContainerStatus()
- .endStatus()
- .build()
+ test("Kill executors") {
+ schedulerBackendUnderTest.start()
+ when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods)
+ when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods)
+ when(labeledPods.withLabelIn(SPARK_EXECUTOR_ID_LABEL, "1", "2")).thenReturn(labeledPods)
+ schedulerBackendUnderTest.doKillExecutors(Seq("1", "2"))
+ verify(labeledPods, never()).delete()
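+    // The pods are only deleted once the kill request actually runs on the request-executors service.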
+ requestExecutorsService.runNextPendingCommand()
+ verify(labeledPods).delete()
}
- private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = {
- val resolvedPod = new PodBuilder(expectedPod)
- .editMetadata()
- .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString)
- .endMetadata()
- .build()
- when(executorPodFactory.createExecutorPod(
- executorId.toString,
- APP_ID,
- DRIVER_URL,
- sparkConf.getExecutorEnv,
- driverPod,
- Map.empty)).thenReturn(resolvedPod)
- resolvedPod
+ test("Request total executors") {
+ schedulerBackendUnderTest.start()
+ schedulerBackendUnderTest.doRequestTotalExecutors(5)
+ verify(podAllocator).setTotalExpectedExecutors(3)
+ verify(podAllocator, never()).setTotalExpectedExecutors(5)
+ requestExecutorsService.runNextPendingCommand()
+ verify(podAllocator).setTotalExpectedExecutors(5)
}
+
}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
new file mode 100644
index 0000000000000..a6bc8bce32926
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.k8s
+
+import io.fabric8.kubernetes.api.model.PodBuilder
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep}
+
+class KubernetesExecutorBuilderSuite extends SparkFunSuite {
+ private val BASIC_STEP_TYPE = "basic"
+ private val SECRETS_STEP_TYPE = "mount-secrets"
+ private val ENV_SECRETS_STEP_TYPE = "env-secrets"
+ private val LOCAL_DIRS_STEP_TYPE = "local-dirs"
+
+ private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep])
+ private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep])
+ private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep])
+ private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep])
+
+ private val builderUnderTest = new KubernetesExecutorBuilder(
+ _ => basicFeatureStep,
+ _ => mountSecretsStep,
+ _ => envSecretsStep,
+ _ => localDirsStep)
+
+ test("Basic steps are consistently applied.") {
+ val conf = KubernetesConf(
+ new SparkConf(false),
+ KubernetesExecutorSpecificConf(
+ "executor-id", new PodBuilder().build()),
+ "prefix",
+ "appId",
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Seq.empty[String])
+ validateStepTypesApplied(
+ builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE)
+ }
+
+ test("Apply secrets step if secrets are present.") {
+ val conf = KubernetesConf(
+ new SparkConf(false),
+ KubernetesExecutorSpecificConf(
+ "executor-id", new PodBuilder().build()),
+ "prefix",
+ "appId",
+ Map.empty,
+ Map.empty,
+ Map("secret" -> "secretMountPath"),
+ Map("secret-name" -> "secret-key"),
+ Map.empty,
+ Seq.empty[String])
+ validateStepTypesApplied(
+ builderUnderTest.buildFromFeatures(conf),
+ BASIC_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE,
+ SECRETS_STEP_TYPE,
+ ENV_SECRETS_STEP_TYPE)
+ }
+
+ private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = {
+ assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size)
+ stepTypes.foreach { stepType =>
+ assert(resolvedPod.pod.getMetadata.getLabels.get(stepType) === stepType)
+ }
+ }
+}
diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile
index 491b7cf692478..9badf8556afc3 100644
--- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile
+++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile
@@ -40,7 +40,6 @@ RUN set -ex && \
COPY ${spark_jars} /opt/spark/jars
COPY bin /opt/spark/bin
COPY sbin /opt/spark/sbin
-COPY conf /opt/spark/conf
COPY ${img_path}/spark/entrypoint.sh /opt/
COPY examples /opt/spark/examples
COPY data /opt/spark/data
diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile
new file mode 100644
index 0000000000000..72bb9620b45de
--- /dev/null
+++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile
@@ -0,0 +1,39 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+ARG base_img
+FROM $base_img
+WORKDIR /
+RUN mkdir ${SPARK_HOME}/python
+COPY python/lib ${SPARK_HOME}/python/lib
+# TODO: Investigate running both pip and pip3 via virtualenvs
+RUN apk add --no-cache python && \
+ apk add --no-cache python3 && \
+ python -m ensurepip && \
+ python3 -m ensurepip && \
+    # We remove ensurepip since it adds no functionality once pip is
+    # installed on the image, and it only takes up about 1.6MB
+ rm -r /usr/lib/python*/ensurepip && \
+ pip install --upgrade pip setuptools && \
+    # Python 3 packages may be installed by using pip3.6
+    # Remove the .cache to save space
+ rm -r /root/.cache
+
+ENV PYTHONPATH ${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4j-*.zip
+
+WORKDIR /opt/spark/work-dir
+ENTRYPOINT [ "/opt/entrypoint.sh" ]
diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh
index b9090dc2852a5..acdb4b1f09e0a 100755
--- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh
+++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh
@@ -22,7 +22,10 @@ set -ex
# Check whether there is a passwd entry for the container UID
myuid=$(id -u)
mygid=$(id -g)
+# Turn off -e for getent because it returns an error code when the container runs with an anonymous uid
+set +e
uidentry=$(getent passwd $myuid)
+set -e
# If there is no passwd entry for the container UID, attempt to create one
if [ -z "$uidentry" ] ; then
@@ -41,7 +44,7 @@ fi
shift 1
SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*"
-env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt
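+# Sort SPARK_JAVA_OPT_* by their numeric index so the JVM options are applied in the order they were set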
+env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt
readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt
if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then
SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH"
@@ -50,17 +53,43 @@ if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then
cp -R "$SPARK_MOUNTED_FILES_DIR/." .
fi
+if [ -n "$PYSPARK_FILES" ]; then
+ PYTHONPATH="$PYTHONPATH:$PYSPARK_FILES"
+fi
+
+PYSPARK_ARGS=""
+if [ -n "$PYSPARK_APP_ARGS" ]; then
+ PYSPARK_ARGS="$PYSPARK_APP_ARGS"
+fi
+
+
+if [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "2" ]; then
+ pyv="$(python -V 2>&1)"
+ export PYTHON_VERSION="${pyv:7}"
+ export PYSPARK_PYTHON="python"
+ export PYSPARK_DRIVER_PYTHON="python"
+elif [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "3" ]; then
+ pyv3="$(python3 -V 2>&1)"
+ export PYTHON_VERSION="${pyv3:7}"
+ export PYSPARK_PYTHON="python3"
+ export PYSPARK_DRIVER_PYTHON="python3"
+fi
+
case "$SPARK_K8S_CMD" in
driver)
CMD=(
- ${JAVA_HOME}/bin/java
- "${SPARK_JAVA_OPTS[@]}"
- -cp "$SPARK_CLASSPATH"
- -Xms$SPARK_DRIVER_MEMORY
- -Xmx$SPARK_DRIVER_MEMORY
- -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS
- $SPARK_DRIVER_CLASS
- $SPARK_DRIVER_ARGS
+ "$SPARK_HOME/bin/spark-submit"
+ --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS"
+ --deploy-mode client
+ "$@"
+ )
+ ;;
+ driver-py)
+ CMD=(
+ "$SPARK_HOME/bin/spark-submit"
+ --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS"
+ --deploy-mode client
+ "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS
)
;;
@@ -80,14 +109,6 @@ case "$SPARK_K8S_CMD" in
)
;;
- init)
- CMD=(
- "$SPARK_HOME/bin/spark-class"
- "org.apache.spark.deploy.k8s.SparkPodInitContainer"
- "$@"
- )
- ;;
-
*)
echo "Unknown command: $SPARK_K8S_CMD" 1>&2
exit 1
diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md
new file mode 100644
index 0000000000000..b3863e6b7d1af
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/README.md
@@ -0,0 +1,52 @@
+---
+layout: global
+title: Spark on Kubernetes Integration Tests
+---
+
+# Running the Kubernetes Integration Tests
+
+Note that the integration test framework is currently being heavily revised and
+is subject to change. At present the integration tests only run with Java 8.
+
+The simplest way to run the integration tests is to install and run Minikube, then run the following:
+
+    dev/dev-run-integration-tests.sh
+
+The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should
+run with a minimum of 3 CPUs and 4G of memory:
+
+    minikube start --cpus 3 --memory 4096
+
+You can download Minikube [here](https://github.com/kubernetes/minikube/releases).
+
+# Integration test customization
+
+Configuration of the integration test runtime is done by passing arguments to the test script. The most useful options are outlined below.
+
+## Re-using Docker Images
+
+By default, the test framework builds new Docker images on every test execution. A unique image tag is generated
+and written to the file `target/imageTag.txt`. To reuse the images built in a previous run, or to use a Docker image tag
+that you have already built by other means, pass the tag to the test script:
+
+    dev/dev-run-integration-tests.sh --image-tag <tag>
+
+or, to reuse the images that were built by a previous run of the test framework:
+
+    dev/dev-run-integration-tests.sh --image-tag $(cat target/imageTag.txt)
+
+## Spark Distribution Under Test
+
+The Spark code under test is handed to the integration test system as a tarball. The tarball is specified with the following option:
+
+* `--spark-tgz <path-to-tgz>` - set `<path-to-tgz>` to point to a tarball containing the Spark distribution to test.
+
+TODO: Don't require the packaging of the built Spark artifacts into this tarball, just read them out of the current tree.
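+
+As an illustrative sketch (the `--name` value and Maven profiles below are assumptions that depend on your build),
+such a tarball can be produced from the top-level Spark directory with the distribution script:
+
+    ./dev/make-distribution.sh --name k8s-test --tgz -Pkubernetes -Phadoop-2.7
+
+The resulting `spark-*-bin-k8s-test.tgz` in the top-level directory can then be passed via `--spark-tgz`.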
+
+## Customizing the Namespace and Service Account
+
+* `--namespace <namespace>` - set `<namespace>` to the namespace in which the tests should be run.
+* `--service-account <service account name>` - set `<service account name>` to the name of the Kubernetes service account to
+use in the namespace specified by `--namespace`. The service account is expected to have permissions to get, list, watch,
+and create pods. For clusters with RBAC turned on, the right permissions must be granted to the service account
+in the namespace through an appropriate role and role binding. A reference RBAC configuration is provided in `dev/spark-rbac.yaml`.
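+
+For example (a minimal sketch using the namespace and service account defined in `dev/spark-rbac.yaml`):
+
+    kubectl apply -f dev/spark-rbac.yaml
+    dev/dev-run-integration-tests.sh --namespace spark --service-account spark-sa --spark-tgz <path-to-tgz>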
diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh
new file mode 100755
index 0000000000000..ea893fa39eede
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh
@@ -0,0 +1,93 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+TEST_ROOT_DIR=$(git rev-parse --show-toplevel)/resource-managers/kubernetes/integration-tests
+
+cd "${TEST_ROOT_DIR}"
+
+DEPLOY_MODE="minikube"
+IMAGE_REPO="docker.io/kubespark"
+SPARK_TGZ="N/A"
+IMAGE_TAG="N/A"
+SPARK_MASTER=
+NAMESPACE=
+SERVICE_ACCOUNT=
+
+# Parse arguments
+while (( "$#" )); do
+ case $1 in
+ --image-repo)
+ IMAGE_REPO="$2"
+ shift
+ ;;
+ --image-tag)
+ IMAGE_TAG="$2"
+ shift
+ ;;
+ --deploy-mode)
+ DEPLOY_MODE="$2"
+ shift
+ ;;
+ --spark-tgz)
+ SPARK_TGZ="$2"
+ shift
+ ;;
+ --spark-master)
+ SPARK_MASTER="$2"
+ shift
+ ;;
+ --namespace)
+ NAMESPACE="$2"
+ shift
+ ;;
+ --service-account)
+ SERVICE_ACCOUNT="$2"
+ shift
+ ;;
+ *)
+ break
+ ;;
+ esac
+ shift
+done
+
+cd $TEST_ROOT_DIR
+
+properties=(
+ -Dspark.kubernetes.test.sparkTgz=$SPARK_TGZ \
+ -Dspark.kubernetes.test.imageTag=$IMAGE_TAG \
+ -Dspark.kubernetes.test.imageRepo=$IMAGE_REPO \
+ -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE
+)
+
+if [ -n "$NAMESPACE" ];
+then
+ properties=( ${properties[@]} -Dspark.kubernetes.test.namespace=$NAMESPACE )
+fi
+
+if [ -n "$SERVICE_ACCOUNT" ];
+then
+ properties=( ${properties[@]} -Dspark.kubernetes.test.serviceAccountName=$SERVICE_ACCOUNT )
+fi
+
+if [ -n "$SPARK_MASTER" ];
+then
+ properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER )
+fi
+
+../../../build/mvn integration-test ${properties[@]}
diff --git a/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml
new file mode 100644
index 0000000000000..a4c242f2f2645
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+apiVersion: v1
+kind: Namespace
+metadata:
+ name: spark
+---
+apiVersion: v1
+kind: ServiceAccount
+metadata:
+ name: spark-sa
+ namespace: spark
+---
+apiVersion: rbac.authorization.k8s.io/v1beta1
+kind: ClusterRole
+metadata:
+ name: spark-role
+rules:
+- apiGroups:
+ - ""
+ resources:
+ - "pods"
+ verbs:
+ - "*"
+---
+apiVersion: rbac.authorization.k8s.io/v1beta1
+kind: ClusterRoleBinding
+metadata:
+ name: spark-role-binding
+subjects:
+- kind: ServiceAccount
+ name: spark-sa
+ namespace: spark
+roleRef:
+ kind: ClusterRole
+ name: spark-role
+ apiGroup: rbac.authorization.k8s.io
\ No newline at end of file
diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml
new file mode 100644
index 0000000000000..520bda89e034d
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/pom.xml
@@ -0,0 +1,155 @@
+
+
+
+ 4.0.0
+
+ org.apache.spark
+ spark-parent_2.11
+ 2.4.0-SNAPSHOT
+ ../../../pom.xml
+
+
+ spark-kubernetes-integration-tests_2.11
+ spark-kubernetes-integration-tests
+
+ 1.3.0
+ 1.4.0
+
+ 3.0.0
+ 3.2.2
+ 1.0
+ kubernetes-integration-tests
+ ${project.build.directory}/spark-dist-unpacked
+ N/A
+ ${project.build.directory}/imageTag.txt
+ minikube
+ docker.io/kubespark
+
+
+ jar
+ Spark Project Kubernetes Integration Tests
+
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ ${project.version}
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ ${project.version}
+ test-jar
+ test
+
+
+ io.fabric8
+ kubernetes-client
+ ${kubernetes-client.version}
+
+
+
+
+
+
+ org.codehaus.mojo
+ exec-maven-plugin
+ ${exec-maven-plugin.version}
+
+
+ setup-integration-test-env
+ pre-integration-test
+
+ exec
+
+
+ scripts/setup-integration-test-env.sh
+
+ --unpacked-spark-tgz
+ ${spark.kubernetes.test.unpackSparkDir}
+
+ --image-repo
+ ${spark.kubernetes.test.imageRepo}
+
+ --image-tag
+ ${spark.kubernetes.test.imageTag}
+
+ --image-tag-output-file
+ ${spark.kubernetes.test.imageTagFile}
+
+ --deploy-mode
+ ${spark.kubernetes.test.deployMode}
+
+ --spark-tgz
+ ${spark.kubernetes.test.sparkTgz}
+
+
+
+
+
+
+
+ org.scalatest
+ scalatest-maven-plugin
+ ${scalatest-maven-plugin.version}
+
+ ${project.build.directory}/surefire-reports
+ .
+ SparkTestSuite.txt
+ -ea -Xmx3g -XX:ReservedCodeCacheSize=512m ${extraScalaTestArgs}
+
+
+ file:src/test/resources/log4j.properties
+ true
+ ${spark.kubernetes.test.imageTagFile}
+ ${spark.kubernetes.test.unpackSparkDir}
+ ${spark.kubernetes.test.imageRepo}
+ ${spark.kubernetes.test.deployMode}
+ ${spark.kubernetes.test.master}
+ ${spark.kubernetes.test.namespace}
+ ${spark.kubernetes.test.serviceAccountName}
+
+ ${test.exclude.tags}
+
+
+
+ test
+
+ test
+
+
+
+ (?<!Suite)
+
+
+
+ integration-test
+ integration-test
+
+ test
+
+
+
+
+
+
+
+
+
diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh
new file mode 100755
index 0000000000000..ccfb8e767c529
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh
@@ -0,0 +1,91 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+TEST_ROOT_DIR=$(git rev-parse --show-toplevel)
+UNPACKED_SPARK_TGZ="$TEST_ROOT_DIR/target/spark-dist-unpacked"
+IMAGE_TAG_OUTPUT_FILE="$TEST_ROOT_DIR/target/image-tag.txt"
+DEPLOY_MODE="minikube"
+IMAGE_REPO="docker.io/kubespark"
+IMAGE_TAG="N/A"
+SPARK_TGZ="N/A"
+
+# Parse arguments
+while (( "$#" )); do
+ case $1 in
+ --unpacked-spark-tgz)
+ UNPACKED_SPARK_TGZ="$2"
+ shift
+ ;;
+ --image-repo)
+ IMAGE_REPO="$2"
+ shift
+ ;;
+ --image-tag)
+ IMAGE_TAG="$2"
+ shift
+ ;;
+ --image-tag-output-file)
+ IMAGE_TAG_OUTPUT_FILE="$2"
+ shift
+ ;;
+ --deploy-mode)
+ DEPLOY_MODE="$2"
+ shift
+ ;;
+ --spark-tgz)
+ SPARK_TGZ="$2"
+ shift
+ ;;
+ *)
+ break
+ ;;
+ esac
+ shift
+done
+
+if [[ $SPARK_TGZ == "N/A" ]];
+then
+ echo "Must specify a Spark tarball to build Docker images against with --spark-tgz." && exit 1;
+fi
+
+rm -rf $UNPACKED_SPARK_TGZ
+mkdir -p $UNPACKED_SPARK_TGZ
+tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ;
+
+if [[ $IMAGE_TAG == "N/A" ]];
+then
+ IMAGE_TAG=$(uuidgen);
+ cd $UNPACKED_SPARK_TGZ
+ if [[ $DEPLOY_MODE == cloud ]] ;
+ then
+ $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build
+ if [[ $IMAGE_REPO == gcr.io* ]] ;
+ then
+ gcloud docker -- push $IMAGE_REPO/spark:$IMAGE_TAG
+ else
+ $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG push
+ fi
+ else
+ # The -m option tells docker-image-tool.sh to build against Minikube's Docker daemon.
+ $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build
+ fi
+ cd -
+fi
+
+rm -f $IMAGE_TAG_OUTPUT_FILE
+echo -n $IMAGE_TAG > $IMAGE_TAG_OUTPUT_FILE
diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties b/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties
new file mode 100644
index 0000000000000..866126bc3c1c2
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties
@@ -0,0 +1,31 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/integration-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=target/integration-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Ignore messages below warning level from a few verbose libraries.
+log4j.logger.com.sun.jersey=WARN
+log4j.logger.org.apache.hadoop=WARN
+log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.mortbay=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
new file mode 100644
index 0000000000000..65c513cf241a4
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.integrationtest
+
+import java.io.File
+import java.nio.file.{Path, Paths}
+import java.util.UUID
+import java.util.regex.Pattern
+
+import scala.collection.JavaConverters._
+
+import com.google.common.io.PatternFilenameFilter
+import io.fabric8.kubernetes.api.model.{Container, Pod}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.scalatest.concurrent.{Eventually, PatienceConfiguration}
+import org.scalatest.time.{Minutes, Seconds, Span}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory}
+import org.apache.spark.deploy.k8s.integrationtest.config._
+
+private[spark] class KubernetesSuite extends SparkFunSuite
+ with BeforeAndAfterAll with BeforeAndAfter {
+
+ import KubernetesSuite._
+
+ private var testBackend: IntegrationTestBackend = _
+ private var sparkHomeDir: Path = _
+ private var kubernetesTestComponents: KubernetesTestComponents = _
+ private var sparkAppConf: SparkAppConf = _
+ private var image: String = _
+ private var containerLocalSparkDistroExamplesJar: String = _
+ private var appLocator: String = _
+ private var driverPodName: String = _
+
+ override def beforeAll(): Unit = {
+ // The scalatest-maven-plugin sets system properties that are referenced but not defined to the
+ // string "null". Remove these null-value properties before initializing the test backend.
+ val nullValueProperties = System.getProperties.asScala
+ .filter(entry => entry._2.equals("null"))
+ .map(entry => entry._1.toString)
+ nullValueProperties.foreach { key =>
+ System.clearProperty(key)
+ }
+
+ val sparkDirProp = System.getProperty("spark.kubernetes.test.unpackSparkDir")
+ require(sparkDirProp != null, "Spark home directory must be provided in system properties.")
+ sparkHomeDir = Paths.get(sparkDirProp)
+ require(sparkHomeDir.toFile.isDirectory,
+ s"No directory found for spark home specified at $sparkHomeDir.")
+ val imageTag = getTestImageTag
+ val imageRepo = getTestImageRepo
+ image = s"$imageRepo/spark:$imageTag"
+
+ val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars"))
+ .toFile
+ .listFiles(new PatternFilenameFilter(Pattern.compile("^spark-examples_.*\\.jar$")))(0)
+ containerLocalSparkDistroExamplesJar = s"local:///opt/spark/examples/jars/" +
+ s"${sparkDistroExamplesJarFile.getName}"
+ testBackend = IntegrationTestBackendFactory.getTestBackend
+ testBackend.initialize()
+ kubernetesTestComponents = new KubernetesTestComponents(testBackend.getKubernetesClient)
+ }
+
+ override def afterAll(): Unit = {
+ testBackend.cleanUp()
+ }
+
+ before {
+ appLocator = UUID.randomUUID().toString.replaceAll("-", "")
+ driverPodName = "spark-test-app-" + UUID.randomUUID().toString.replaceAll("-", "")
+ sparkAppConf = kubernetesTestComponents.newSparkAppConf()
+ .set("spark.kubernetes.container.image", image)
+ .set("spark.kubernetes.driver.pod.name", driverPodName)
+ .set("spark.kubernetes.driver.label.spark-app-locator", appLocator)
+ .set("spark.kubernetes.executor.label.spark-app-locator", appLocator)
+ if (!kubernetesTestComponents.hasUserSpecifiedNamespace) {
+ kubernetesTestComponents.createNamespace()
+ }
+ }
+
+ after {
+ if (!kubernetesTestComponents.hasUserSpecifiedNamespace) {
+ kubernetesTestComponents.deleteNamespace()
+ }
+ deleteDriverPod()
+ }
+
+ test("Run SparkPi with no resources") {
+ runSparkPiAndVerifyCompletion()
+ }
+
+ test("Run SparkPi with a very long application name.") {
+ sparkAppConf.set("spark.app.name", "long" * 40)
+ runSparkPiAndVerifyCompletion()
+ }
+
+ test("Run SparkPi with a master URL without a scheme.") {
+ val url = kubernetesTestComponents.kubernetesClient.getMasterUrl
+ val k8sMasterUrl = if (url.getPort < 0) {
+ s"k8s://${url.getHost}"
+ } else {
+ s"k8s://${url.getHost}:${url.getPort}"
+ }
+ sparkAppConf.set("spark.master", k8sMasterUrl)
+ runSparkPiAndVerifyCompletion()
+ }
+
+ test("Run SparkPi with an argument.") {
+ runSparkPiAndVerifyCompletion(appArgs = Array("5"))
+ }
+
+ test("Run SparkPi with custom labels, annotations, and environment variables.") {
+ sparkAppConf
+ .set("spark.kubernetes.driver.label.label1", "label1-value")
+ .set("spark.kubernetes.driver.label.label2", "label2-value")
+ .set("spark.kubernetes.driver.annotation.annotation1", "annotation1-value")
+ .set("spark.kubernetes.driver.annotation.annotation2", "annotation2-value")
+ .set("spark.kubernetes.driverEnv.ENV1", "VALUE1")
+ .set("spark.kubernetes.driverEnv.ENV2", "VALUE2")
+ .set("spark.kubernetes.executor.label.label1", "label1-value")
+ .set("spark.kubernetes.executor.label.label2", "label2-value")
+ .set("spark.kubernetes.executor.annotation.annotation1", "annotation1-value")
+ .set("spark.kubernetes.executor.annotation.annotation2", "annotation2-value")
+ .set("spark.executorEnv.ENV1", "VALUE1")
+ .set("spark.executorEnv.ENV2", "VALUE2")
+
+ runSparkPiAndVerifyCompletion(
+ driverPodChecker = (driverPod: Pod) => {
+ doBasicDriverPodCheck(driverPod)
+ checkCustomSettings(driverPod)
+ },
+ executorPodChecker = (executorPod: Pod) => {
+ doBasicExecutorPodCheck(executorPod)
+ checkCustomSettings(executorPod)
+ })
+ }
+
+ // TODO(ssuchter): Enable the below after debugging
+ // test("Run PageRank using remote data file") {
+ // sparkAppConf
+ // .set("spark.kubernetes.mountDependencies.filesDownloadDir",
+ // CONTAINER_LOCAL_FILE_DOWNLOAD_PATH)
+ // .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE)
+ // runSparkPageRankAndVerifyCompletion(
+ // appArgs = Array(CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE))
+ // }
+
+ private def runSparkPiAndVerifyCompletion(
+ appResource: String = containerLocalSparkDistroExamplesJar,
+ driverPodChecker: Pod => Unit = doBasicDriverPodCheck,
+ executorPodChecker: Pod => Unit = doBasicExecutorPodCheck,
+ appArgs: Array[String] = Array.empty[String],
+ appLocator: String = appLocator): Unit = {
+ runSparkApplicationAndVerifyCompletion(
+ appResource,
+ SPARK_PI_MAIN_CLASS,
+ Seq("Pi is roughly 3"),
+ appArgs,
+ driverPodChecker,
+ executorPodChecker,
+ appLocator)
+ }
+
+ private def runSparkPageRankAndVerifyCompletion(
+ appResource: String = containerLocalSparkDistroExamplesJar,
+ driverPodChecker: Pod => Unit = doBasicDriverPodCheck,
+ executorPodChecker: Pod => Unit = doBasicExecutorPodCheck,
+ appArgs: Array[String],
+ appLocator: String = appLocator): Unit = {
+ runSparkApplicationAndVerifyCompletion(
+ appResource,
+ SPARK_PAGE_RANK_MAIN_CLASS,
+ Seq("1 has rank", "2 has rank", "3 has rank", "4 has rank"),
+ appArgs,
+ driverPodChecker,
+ executorPodChecker,
+ appLocator)
+ }
+
+ private def runSparkApplicationAndVerifyCompletion(
+ appResource: String,
+ mainClass: String,
+ expectedLogOnCompletion: Seq[String],
+ appArgs: Array[String],
+ driverPodChecker: Pod => Unit,
+ executorPodChecker: Pod => Unit,
+ appLocator: String): Unit = {
+ val appArguments = SparkAppArguments(
+ mainAppResource = appResource,
+ mainClass = mainClass,
+ appArgs = appArgs)
+ SparkAppLauncher.launch(appArguments, sparkAppConf, TIMEOUT.value.toSeconds.toInt, sparkHomeDir)
+
+ val driverPod = kubernetesTestComponents.kubernetesClient
+ .pods()
+ .withLabel("spark-app-locator", appLocator)
+ .withLabel("spark-role", "driver")
+ .list()
+ .getItems
+ .get(0)
+ driverPodChecker(driverPod)
+
+ val executorPods = kubernetesTestComponents.kubernetesClient
+ .pods()
+ .withLabel("spark-app-locator", appLocator)
+ .withLabel("spark-role", "executor")
+ .list()
+ .getItems
+ executorPods.asScala.foreach { pod =>
+ executorPodChecker(pod)
+ }
+
+ Eventually.eventually(TIMEOUT, INTERVAL) {
+ expectedLogOnCompletion.foreach { e =>
+ assert(kubernetesTestComponents.kubernetesClient
+ .pods()
+ .withName(driverPod.getMetadata.getName)
+ .getLog
+ .contains(e), "The application did not complete.")
+ }
+ }
+ }
+
+ private def doBasicDriverPodCheck(driverPod: Pod): Unit = {
+ assert(driverPod.getMetadata.getName === driverPodName)
+ assert(driverPod.getSpec.getContainers.get(0).getImage === image)
+ assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver")
+ }
+
+ private def doBasicExecutorPodCheck(executorPod: Pod): Unit = {
+ assert(executorPod.getSpec.getContainers.get(0).getImage === image)
+ assert(executorPod.getSpec.getContainers.get(0).getName === "executor")
+ }
+
+ private def checkCustomSettings(pod: Pod): Unit = {
+ assert(pod.getMetadata.getLabels.get("label1") === "label1-value")
+ assert(pod.getMetadata.getLabels.get("label2") === "label2-value")
+ assert(pod.getMetadata.getAnnotations.get("annotation1") === "annotation1-value")
+ assert(pod.getMetadata.getAnnotations.get("annotation2") === "annotation2-value")
+
+ val container = pod.getSpec.getContainers.get(0)
+ val envVars = container
+ .getEnv
+ .asScala
+ .map { env =>
+ (env.getName, env.getValue)
+ }
+ .toMap
+ assert(envVars("ENV1") === "VALUE1")
+ assert(envVars("ENV2") === "VALUE2")
+ }
+
+ private def deleteDriverPod(): Unit = {
+ kubernetesTestComponents.kubernetesClient.pods().withName(driverPodName).delete()
+ Eventually.eventually(TIMEOUT, INTERVAL) {
+ assert(kubernetesTestComponents.kubernetesClient
+ .pods()
+ .withName(driverPodName)
+ .get() == null)
+ }
+ }
+}
+
+private[spark] object KubernetesSuite {
+
+ val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes))
+ val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds))
+ val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi"
+ val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank"
+
+ // val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files"
+
+ // val REMOTE_PAGE_RANK_DATA_FILE =
+ // "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt"
+ // val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE =
+ // s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt"
+
+ // case object ShuffleNotReadyException extends Exception
+}
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala
new file mode 100644
index 0000000000000..48727142dd052
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.integrationtest
+
+import java.nio.file.{Path, Paths}
+import java.util.UUID
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import io.fabric8.kubernetes.client.DefaultKubernetesClient
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.internal.Logging
+
+private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesClient) {
+
+ val namespaceOption = Option(System.getProperty("spark.kubernetes.test.namespace"))
+ val hasUserSpecifiedNamespace = namespaceOption.isDefined
+ val namespace = namespaceOption.getOrElse(UUID.randomUUID().toString.replaceAll("-", ""))
+ private val serviceAccountName =
+ Option(System.getProperty("spark.kubernetes.test.serviceAccountName"))
+ .getOrElse("default")
+ val kubernetesClient = defaultClient.inNamespace(namespace)
+ val clientConfig = kubernetesClient.getConfiguration
+
+ def createNamespace(): Unit = {
+ defaultClient.namespaces.createNew()
+ .withNewMetadata()
+ .withName(namespace)
+ .endMetadata()
+ .done()
+ }
+
+ def deleteNamespace(): Unit = {
+ defaultClient.namespaces.withName(namespace).delete()
+ Eventually.eventually(KubernetesSuite.TIMEOUT, KubernetesSuite.INTERVAL) {
+ val namespaceList = defaultClient
+ .namespaces()
+ .list()
+ .getItems
+ .asScala
+ require(!namespaceList.exists(_.getMetadata.getName == namespace))
+ }
+ }
+
+ def newSparkAppConf(): SparkAppConf = {
+ new SparkAppConf()
+ .set("spark.master", s"k8s://${kubernetesClient.getMasterUrl}")
+ .set("spark.kubernetes.namespace", namespace)
+ .set("spark.executor.memory", "500m")
+ .set("spark.executor.cores", "1")
+ .set("spark.executors.instances", "1")
+ .set("spark.app.name", "spark-test-app")
+ .set("spark.ui.enabled", "true")
+ .set("spark.testing", "false")
+ .set("spark.kubernetes.submission.waitAppCompletion", "false")
+ .set("spark.kubernetes.authenticate.driver.serviceAccountName", serviceAccountName)
+ }
+}
+
+private[spark] class SparkAppConf {
+
+ private val map = mutable.Map[String, String]()
+
+ def set(key: String, value: String): SparkAppConf = {
+ map.put(key, value)
+ this
+ }
+
+ def get(key: String): String = map.getOrElse(key, "")
+
+ def setJars(jars: Seq[String]): Unit = set("spark.jars", jars.mkString(","))
+
+ override def toString: String = map.toString
+
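+ // Render the configuration as repeated "--conf key=value" arguments for spark-submit.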
+ def toStringArray: Iterable[String] = map.toList.flatMap(t => List("--conf", s"${t._1}=${t._2}"))
+}
+
+private[spark] case class SparkAppArguments(
+ mainAppResource: String,
+ mainClass: String,
+ appArgs: Array[String])
+
+private[spark] object SparkAppLauncher extends Logging {
+
+ def launch(
+ appArguments: SparkAppArguments,
+ appConf: SparkAppConf,
+ timeoutSecs: Int,
+ sparkHomeDir: Path): Unit = {
+ val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit"))
+ logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf")
+ val appArgsArray =
+ if (appArguments.appArgs.length > 0) Array(appArguments.appArgs.mkString(" "))
+ else Array[String]()
+ val commandLine = (Array(sparkSubmitExecutable.toFile.getAbsolutePath,
+ "--deploy-mode", "cluster",
+ "--class", appArguments.mainClass,
+ "--master", appConf.get("spark.master")
+ ) ++ appConf.toStringArray :+
+ appArguments.mainAppResource) ++
+ appArgsArray
+ ProcessUtils.executeProcess(commandLine, timeoutSecs)
+ }
+}
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala
new file mode 100644
index 0000000000000..d8f3a6cec05c3
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.integrationtest
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable.ArrayBuffer
+import scala.io.Source
+
+import org.apache.spark.internal.Logging
+
+object ProcessUtils extends Logging {
+ /**
+ * executeProcess is used to run a command and return the output if it
+ * completes within timeout seconds.
+ */
+ def executeProcess(fullCommand: Array[String], timeout: Long): Seq[String] = {
+ val pb = new ProcessBuilder().command(fullCommand: _*)
+ pb.redirectErrorStream(true)
+ val proc = pb.start()
+ val outputLines = new ArrayBuffer[String]
+ Utils.tryWithResource(proc.getInputStream)(
+ Source.fromInputStream(_, "UTF-8").getLines().foreach { line =>
+ logInfo(line)
+ outputLines += line
+ })
+ assert(proc.waitFor(timeout, TimeUnit.SECONDS),
+ s"Timed out while executing ${fullCommand.mkString(" ")}")
+ assert(proc.exitValue == 0, s"Failed to execute ${fullCommand.mkString(" ")}")
+ outputLines
+ }
+}
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala
new file mode 100644
index 0000000000000..f1fd6dc19ce54
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.integrationtest
+
+import java.util.concurrent.TimeUnit
+
+import com.google.common.util.concurrent.SettableFuture
+import io.fabric8.kubernetes.api.model.HasMetadata
+import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher}
+import io.fabric8.kubernetes.client.Watcher.Action
+import io.fabric8.kubernetes.client.internal.readiness.Readiness
+
+private[spark] class SparkReadinessWatcher[T <: HasMetadata] extends Watcher[T] {
+
+ private val signal = SettableFuture.create[Boolean]
+
+ override def eventReceived(action: Action, resource: T): Unit = {
+ if ((action == Action.MODIFIED || action == Action.ADDED) &&
+ Readiness.isReady(resource)) {
+ signal.set(true)
+ }
+ }
+
+ override def onClose(cause: KubernetesClientException): Unit = {}
+
+ def waitUntilReady(): Boolean = signal.get(60, TimeUnit.SECONDS)
+}
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala
new file mode 100644
index 0000000000000..663f8b6523ac8
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.integrationtest
+
+import java.io.Closeable
+import java.net.URI
+
+import org.apache.spark.internal.Logging
+
+object Utils extends Logging {
+
+ def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = {
+ val resource = createResource
+ try f.apply(resource) finally resource.close()
+ }
+}
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala
new file mode 100644
index 0000000000000..284712c6d250e
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.k8s.integrationtest.backend
+
+import io.fabric8.kubernetes.client.DefaultKubernetesClient
+
+import org.apache.spark.deploy.k8s.integrationtest.backend.minikube.MinikubeTestBackend
+
+private[spark] trait IntegrationTestBackend {
+ def initialize(): Unit
+ def getKubernetesClient: DefaultKubernetesClient
+ def cleanUp(): Unit = {}
+}
+
+private[spark] object IntegrationTestBackendFactory {
+ val deployModeConfigKey = "spark.kubernetes.test.deployMode"
+
+ def getTestBackend: IntegrationTestBackend = {
+ val deployMode = Option(System.getProperty(deployModeConfigKey))
+ .getOrElse("minikube")
+ if (deployMode == "minikube") {
+ MinikubeTestBackend
+ } else {
+ throw new IllegalArgumentException(
+ "Invalid " + deployModeConfigKey + ": " + deployMode)
+ }
+ }
+}
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala
new file mode 100644
index 0000000000000..6494cbc18f33e
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.integrationtest.backend.minikube
+
+import java.io.File
+import java.nio.file.Paths
+
+import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient}
+
+import org.apache.spark.deploy.k8s.integrationtest.ProcessUtils
+import org.apache.spark.internal.Logging
+
+// TODO support windows
+private[spark] object Minikube extends Logging {
+
+ private val MINIKUBE_STARTUP_TIMEOUT_SECONDS = 60
+
+ def getMinikubeIp: String = {
+ val outputs = executeMinikube("ip")
+ .filter(_.matches("^\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}$"))
+ assert(outputs.size == 1, "Unexpected amount of output from minikube ip")
+ outputs.head
+ }
+
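+ // The status line format differs across Minikube versions ("minikubeVM: <status>" vs.
+ // "minikube: <status>"), so both prefixes are handled below.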
+ def getMinikubeStatus: MinikubeStatus.Value = {
+ val statusString = executeMinikube("status")
+ .filter(line => line.contains("minikubeVM: ") || line.contains("minikube:"))
+ .head
+ .replaceFirst("minikubeVM: ", "")
+ .replaceFirst("minikube: ", "")
+ MinikubeStatus.unapply(statusString)
+ .getOrElse(throw new IllegalStateException(s"Unknown status $statusString"))
+ }
+
+ def getKubernetesClient: DefaultKubernetesClient = {
+ val kubernetesMaster = s"https://${getMinikubeIp}:8443"
+ val userHome = System.getProperty("user.home")
+ val kubernetesConf = new ConfigBuilder()
+ .withApiVersion("v1")
+ .withMasterUrl(kubernetesMaster)
+ .withCaCertFile(Paths.get(userHome, ".minikube", "ca.crt").toFile.getAbsolutePath)
+ .withClientCertFile(Paths.get(userHome, ".minikube", "apiserver.crt").toFile.getAbsolutePath)
+ .withClientKeyFile(Paths.get(userHome, ".minikube", "apiserver.key").toFile.getAbsolutePath)
+ .build()
+ new DefaultKubernetesClient(kubernetesConf)
+ }
+
+ private def executeMinikube(action: String, args: String*): Seq[String] = {
+ ProcessUtils.executeProcess(
+ Array("bash", "-c", s"minikube $action") ++ args, MINIKUBE_STARTUP_TIMEOUT_SECONDS)
+ }
+}
+
+private[spark] object MinikubeStatus extends Enumeration {
+
+ // The following states are listed according to
+ // https://github.com/docker/machine/blob/master/libmachine/state/state.go.
+ val STARTING = status("Starting")
+ val RUNNING = status("Running")
+ val PAUSED = status("Paused")
+ val STOPPING = status("Stopping")
+ val STOPPED = status("Stopped")
+ val ERROR = status("Error")
+ val TIMEOUT = status("Timeout")
+ val SAVED = status("Saved")
+ val NONE = status("")
+
+ def status(value: String): Value = new Val(nextId, value)
+ def unapply(s: String): Option[Value] = values.find(s == _.toString)
+}
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala
new file mode 100644
index 0000000000000..cb9324179d70e
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.integrationtest.backend.minikube
+
+import io.fabric8.kubernetes.client.DefaultKubernetesClient
+
+import org.apache.spark.deploy.k8s.integrationtest.backend.IntegrationTestBackend
+
+private[spark] object MinikubeTestBackend extends IntegrationTestBackend {
+
+ private var defaultClient: DefaultKubernetesClient = _
+
+ override def initialize(): Unit = {
+ val minikubeStatus = Minikube.getMinikubeStatus
+ require(minikubeStatus == MinikubeStatus.RUNNING,
+ s"Minikube must be running to use the Minikube backend for integration tests." +
+ s" Current status is: $minikubeStatus.")
+ defaultClient = Minikube.getKubernetesClient
+ }
+
+ override def cleanUp(): Unit = {
+ super.cleanUp()
+ }
+
+ override def getKubernetesClient: DefaultKubernetesClient = {
+ defaultClient
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala
similarity index 50%
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala
rename to resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala
index 0daa7b95e8aae..a81ef455c6766 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala
@@ -14,23 +14,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.deploy.k8s.submit.steps.initcontainer
+package org.apache.spark.deploy.k8s.integrationtest
-import org.apache.spark.deploy.k8s.MountSecretsBootstrap
+import java.io.File
-/**
- * An init-container configuration step for mounting user-specified secrets onto user-specified
- * paths.
- *
- * @param bootstrap a utility actually handling mounting of the secrets
- */
-private[spark] class InitContainerMountSecretsStep(
- bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep {
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+
+package object config {
+ def getTestImageTag: String = {
+ val imageTagFileProp = System.getProperty("spark.kubernetes.test.imageTagFile")
+ require(imageTagFileProp != null, "Image tag file must be provided in system properties.")
+ val imageTagFile = new File(imageTagFileProp)
+ require(imageTagFile.isFile, s"No file found for image tag at ${imageTagFile.getAbsolutePath}.")
+ Files.toString(imageTagFile, Charsets.UTF_8).trim
+ }
- override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = {
- // Mount the secret volumes given that the volumes have already been added to the driver pod
- // when mounting the secrets into the main driver container.
- val initContainer = bootstrap.mountSecrets(spec.initContainer)
- spec.copy(initContainer = initContainer)
+ def getTestImageRepo: String = {
+ val imageRepo = System.getProperty("spark.kubernetes.test.imageRepo")
+ require(imageRepo != null, "Image repo must be provided in system properties.")
+ imageRepo
}
}
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala
new file mode 100644
index 0000000000000..0807a68cd823c
--- /dev/null
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.integrationtest
+
+package object constants {
+ val MINIKUBE_TEST_BACKEND = "minikube"
+ val GCE_TEST_BACKEND = "gce"
+}
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala
index aa378c9d340f1..ccf33e8d4283c 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala
@@ -19,7 +19,7 @@ package org.apache.spark.deploy.mesos
import java.util.concurrent.CountDownLatch
-import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.mesos.config._
import org.apache.spark.deploy.mesos.ui.MesosClusterUI
import org.apache.spark.deploy.rest.mesos.MesosRestServer
@@ -100,7 +100,13 @@ private[mesos] object MesosClusterDispatcher
Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler)
Utils.initDaemon(log)
val conf = new SparkConf
- val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf)
+ val dispatcherArgs = try {
+ new MesosClusterDispatcherArguments(args, conf)
+ } catch {
+ case e: SparkException =>
+ printErrorAndExit(e.getMessage())
+ null
+ }
conf.setMaster(dispatcherArgs.masterUrl)
conf.setAppName(dispatcherArgs.name)
dispatcherArgs.zookeeperUrl.foreach { z =>
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
index 096bb4e1af688..267a4283db9e6 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
import scala.collection.mutable
import org.apache.spark.SparkConf
+import org.apache.spark.deploy.SparkSubmitUtils
import org.apache.spark.util.{IntParam, Utils}
private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) {
@@ -95,9 +96,8 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
parse(tail)
case ("--conf") :: value :: tail =>
- val pair = MesosClusterDispatcher.
- parseSparkConfProperty(value)
- confProperties(pair._1) = pair._2
+ val (k, v) = SparkSubmitUtils.parseSparkConfProperty(value)
+ confProperties(k) = v
parse(tail)
case ("--help") :: tail =>
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
index 022191d0070fd..91f64141e5318 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
@@ -39,7 +39,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
Cannot find driver {driverId}
- return UIUtils.basicSparkPage(content, s"Details for Job $driverId")
+ return UIUtils.basicSparkPage(request, content, s"Details for Job $driverId")
}
val driverState = state.get
val driverHeaders = Seq("Driver property", "Value")
@@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
retryHeaders, retryRow, Iterable.apply(driverState.description.retryState))
val content =
@@ -87,7 +87,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
;
- UIUtils.basicSparkPage(content, s"Details for Job $driverId")
+ UIUtils.basicSparkPage(request, content, s"Details for Job $driverId")
}
private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = {
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala
index 88a6614d51384..c53285331ea68 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala
@@ -62,7 +62,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage(
{retryTable}
;
- UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster")
+ UIUtils.basicSparkPage(request, content, "Spark Drivers for Mesos cluster")
}
private def queuedRow(submission: MesosDriverDescription): Seq[Node] = {
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala
index 604978967d6db..15bbe60d6c8fb 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala
@@ -40,7 +40,7 @@ private[spark] class MesosClusterUI(
override def initialize() {
attachPage(new MesosClusterPage(this))
attachPage(new DriverPage(this))
- attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static"))
+ addStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR)
}
}
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index d224a7325820a..7d80eedcc43ce 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -30,8 +30,7 @@ import org.apache.mesos.Protos.Environment.Variable
import org.apache.mesos.Protos.TaskStatus.Reason
import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState}
-import org.apache.spark.deploy.mesos.MesosDriverDescription
-import org.apache.spark.deploy.mesos.config
+import org.apache.spark.deploy.mesos.{config, MesosDriverDescription}
import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.Utils
@@ -418,6 +417,18 @@ private[spark] class MesosClusterScheduler(
envBuilder.build()
}
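+ // For example, with --conf spark.mesos.appJar.local.resolution.mode=container, a jar URL such as
+ // local:///opt/app/app.jar (illustrative path) is treated as already present inside the container
+ // image: it is not added to the Mesos fetcher URIs and is used directly as the primary resource.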
+ private def isContainerLocalAppJar(desc: MesosDriverDescription): Boolean = {
+ val isLocalJar = desc.jarUrl.startsWith("local://")
+ val isContainerLocal = desc.conf.getOption("spark.mesos.appJar.local.resolution.mode").exists {
+ case "container" => true
+ case "host" => false
+ case other =>
+ logWarning(s"Unknown spark.mesos.appJar.local.resolution.mode $other, using host.")
+ false
+ }
+ isLocalJar && isContainerLocal
+ }
+
private def getDriverUris(desc: MesosDriverDescription): List[CommandInfo.URI] = {
val confUris = List(conf.getOption("spark.mesos.uris"),
desc.conf.getOption("spark.mesos.uris"),
@@ -425,10 +436,14 @@ private[spark] class MesosClusterScheduler(
_.map(_.split(",").map(_.trim))
).flatten
- val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:")
-
- ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri =>
- CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build())
+ if (isContainerLocalAppJar(desc)) {
+ (confUris ++ getDriverExecutorURI(desc).toList).map(uri =>
+ CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build())
+ } else {
+ val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:")
+ ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri =>
+ CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build())
+ }
}
private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = {
@@ -480,7 +495,14 @@ private[spark] class MesosClusterScheduler(
(cmdExecutable, ".")
}
val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ")
- val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString()
+ val primaryResource = {
+ if (isContainerLocalAppJar(desc)) {
+ new File(desc.jarUrl.stripPrefix("local://")).toString()
+ } else {
+ new File(sandboxPath, desc.jarUrl.split("/").last).toString()
+ }
+ }
+
val appArguments = desc.command.arguments.mkString(" ")
s"$executable $cmdOptions $primaryResource $appArguments"
@@ -530,9 +552,9 @@ private[spark] class MesosClusterScheduler(
.filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) }
.toMap
(defaultConf ++ driverConf).foreach { case (key, value) =>
- options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) }
+ options ++= Seq("--conf", s"${key}=${value}") }
- options
+ options.map(shellEscape)
}
/**
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
index 53f5f61cca486..d35bea4aca311 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
@@ -227,7 +227,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
environment.addVariables(
Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build())
}
- val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "")
+ val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions").map {
+ Utils.substituteAppNExecIds(_, appId, taskId)
+ }.getOrElse("")
// Set the environment variable through a command prefix
// to append to the existing value of the variable
@@ -632,7 +634,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
slave.hostname,
externalShufflePort,
sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs",
- s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"),
+ s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L}ms"),
sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s"))
slave.shuffleRegistered = true
}
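
For context on Utils.substituteAppNExecIds: as documented for spark.executor.extraJavaOptions, it expands application and executor ID placeholders in the option string before the options reach the launched process. A self-contained sketch of the assumed behaviour (not the Utils source):

// Assumed behaviour only; see org.apache.spark.util.Utils for the real implementation.
def substituteIds(opts: String, appId: String, execId: String): String =
  opts.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId)

// substituteIds("-Dlog.file=/var/log/{{APP_ID}}-{{EXECUTOR_ID}}.log", "app-20180401", "7")
//   returns "-Dlog.file=/var/log/app-20180401-7.log"
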
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
index d6d939d246109..71a70ff048ccc 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
@@ -111,7 +111,9 @@ private[spark] class MesosFineGrainedSchedulerBackend(
environment.addVariables(
Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build())
}
- val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("")
+ val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").map {
+ Utils.substituteAppNExecIds(_, appId, execId)
+ }.getOrElse("")
val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p =>
Utils.libraryPathEnvPrefix(Seq(p))
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala
index 7165bfae18a5e..a1bf4f0c048fe 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala
@@ -29,6 +29,7 @@ import org.apache.spark.deploy.security.HadoopDelegationTokenManager
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens
+import org.apache.spark.ui.UIUtils
import org.apache.spark.util.ThreadUtils
@@ -63,7 +64,7 @@ private[spark] class MesosHadoopDelegationTokenManager(
val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds)
logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}")
- (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.getDateOfNextUpdate(rt, 0.75))
+ (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.nextCredentialRenewalTime(rt, conf))
} catch {
case e: Exception =>
logError(s"Failed to fetch Hadoop delegation tokens $e")
@@ -104,8 +105,10 @@ private[spark] class MesosHadoopDelegationTokenManager(
} catch {
case e: Exception =>
// Log the error and try to write new tokens back in an hour
- logWarning("Couldn't broadcast tokens, trying again in an hour", e)
- credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS)
+ val delay = TimeUnit.SECONDS.toMillis(conf.get(config.CREDENTIALS_RENEWAL_RETRY_WAIT))
+ logWarning(
+ s"Couldn't broadcast tokens, trying again in ${UIUtils.formatDuration(delay)}", e)
+ credentialRenewerThread.schedule(this, delay, TimeUnit.MILLISECONDS)
return
}
scheduleRenewal(this)
@@ -135,7 +138,7 @@ private[spark] class MesosHadoopDelegationTokenManager(
"related configurations in the target services.")
currTime
} else {
- SparkHadoopUtil.getDateOfNextUpdate(nextRenewalTime, 0.75)
+ SparkHadoopUtil.nextCredentialRenewalTime(nextRenewalTime, conf)
}
logInfo(s"Time of next renewal is in ${timeOfNextRenewal - System.currentTimeMillis()} ms")
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
index e75450369ad85..ecbcc960fc5a0 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
@@ -17,6 +17,8 @@
package org.apache.spark.scheduler.cluster.mesos
+import java.io.File
+import java.nio.charset.StandardCharsets
import java.util.{List => JList}
import java.util.concurrent.CountDownLatch
@@ -25,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import com.google.common.base.Splitter
+import com.google.common.io.Files
import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver}
import org.apache.mesos.Protos.{TaskState => MesosTaskState, _}
import org.apache.mesos.Protos.FrameworkInfo.Capability
@@ -71,26 +74,15 @@ trait MesosSchedulerUtils extends Logging {
failoverTimeout: Option[Double] = None,
frameworkId: Option[String] = None): SchedulerDriver = {
val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName)
- val credBuilder = Credential.newBuilder()
+ fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(
+ conf.get(DRIVER_HOST_ADDRESS)))
webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) }
checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) }
failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) }
frameworkId.foreach { id =>
fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build())
}
- fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(
- conf.get(DRIVER_HOST_ADDRESS)))
- conf.getOption("spark.mesos.principal").foreach { principal =>
- fwInfoBuilder.setPrincipal(principal)
- credBuilder.setPrincipal(principal)
- }
- conf.getOption("spark.mesos.secret").foreach { secret =>
- credBuilder.setSecret(secret)
- }
- if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) {
- throw new SparkException(
- "spark.mesos.principal must be configured when spark.mesos.secret is set")
- }
+
conf.getOption("spark.mesos.role").foreach { role =>
fwInfoBuilder.setRole(role)
}
@@ -98,6 +90,7 @@ trait MesosSchedulerUtils extends Logging {
if (maxGpus > 0) {
fwInfoBuilder.addCapabilities(Capability.newBuilder().setType(Capability.Type.GPU_RESOURCES))
}
+ val credBuilder = buildCredentials(conf, fwInfoBuilder)
if (credBuilder.hasPrincipal) {
new MesosSchedulerDriver(
scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build())
@@ -106,6 +99,40 @@ trait MesosSchedulerUtils extends Logging {
}
}
+ def buildCredentials(
+ conf: SparkConf,
+ fwInfoBuilder: Protos.FrameworkInfo.Builder): Protos.Credential.Builder = {
+ val credBuilder = Credential.newBuilder()
+ conf.getOption("spark.mesos.principal")
+ .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL")))
+ .orElse(
+ conf.getOption("spark.mesos.principal.file")
+ .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL_FILE")))
+ .map { principalFile =>
+ Files.toString(new File(principalFile), StandardCharsets.UTF_8)
+ }
+ ).foreach { principal =>
+ fwInfoBuilder.setPrincipal(principal)
+ credBuilder.setPrincipal(principal)
+ }
+ conf.getOption("spark.mesos.secret")
+ .orElse(Option(conf.getenv("SPARK_MESOS_SECRET")))
+ .orElse(
+ conf.getOption("spark.mesos.secret.file")
+ .orElse(Option(conf.getenv("SPARK_MESOS_SECRET_FILE")))
+ .map { secretFile =>
+ Files.toString(new File(secretFile), StandardCharsets.UTF_8)
+ }
+ ).foreach { secret =>
+ credBuilder.setSecret(secret)
+ }
+ if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) {
+ throw new SparkException(
+ "spark.mesos.principal must be configured when spark.mesos.secret is set")
+ }
+ credBuilder
+ }
+
/**
* Starts the MesosSchedulerDriver and stores the current running driver to this new instance.
* This driver is expected to not be running.
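
The resolution order in buildCredentials is: Spark property first, then the corresponding environment variable, then the *.file variants whose contents are read from disk with Guava's Files.toString. A minimal sketch of that precedence, with hypothetical parameter names:

// Sketch of the lookup chain used above for both the principal and the secret.
def resolveCredential(
    prop: Option[String],       // e.g. spark.mesos.principal
    env: Option[String],        // e.g. SPARK_MESOS_PRINCIPAL
    propFile: Option[String],   // e.g. spark.mesos.principal.file
    envFile: Option[String],    // e.g. SPARK_MESOS_PRINCIPAL_FILE
    readFile: String => String): Option[String] =
  prop.orElse(env).orElse(propFile.orElse(envFile).map(readFile))
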
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
index 7df738958f85c..8d90e1a8591ad 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
@@ -17,16 +17,20 @@
package org.apache.spark.scheduler.cluster.mesos
+import java.io.{File, FileNotFoundException}
+
import scala.collection.JavaConverters._
import scala.language.reflectiveCalls
-import org.apache.mesos.Protos.{Resource, Value}
+import com.google.common.io.Files
+import org.apache.mesos.Protos.{FrameworkInfo, Resource, Value}
import org.mockito.Mockito._
import org.scalatest._
import org.scalatest.mockito.MockitoSugar
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.internal.config._
+import org.apache.spark.util.SparkConfWithEnv
class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar {
@@ -237,4 +241,157 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS
val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1}
portsToUse.isEmpty shouldBe true
}
+
+ test("Principal specified via spark.mesos.principal") {
+ val conf = new SparkConf()
+ conf.set("spark.mesos.principal", "test-principal")
+
+ val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ credBuilder.hasPrincipal shouldBe true
+ credBuilder.getPrincipal shouldBe "test-principal"
+ }
+
+ test("Principal specified via spark.mesos.principal.file") {
+ val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt")
+ pFile.deleteOnExit()
+ Files.write("test-principal".getBytes("UTF-8"), pFile)
+ val conf = new SparkConf()
+ conf.set("spark.mesos.principal.file", pFile.getAbsolutePath())
+
+ val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ credBuilder.hasPrincipal shouldBe true
+ credBuilder.getPrincipal shouldBe "test-principal"
+ }
+
+ test("Principal specified via spark.mesos.principal.file that does not exist") {
+ val conf = new SparkConf()
+ conf.set("spark.mesos.principal.file", "/tmp/does-not-exist")
+
+ intercept[FileNotFoundException] {
+ utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ }
+ }
+
+ test("Principal specified via SPARK_MESOS_PRINCIPAL") {
+ val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "test-principal"))
+
+ val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ credBuilder.hasPrincipal shouldBe true
+ credBuilder.getPrincipal shouldBe "test-principal"
+ }
+
+ test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE") {
+ val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt")
+ pFile.deleteOnExit()
+ Files.write("test-principal".getBytes("UTF-8"), pFile)
+ val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> pFile.getAbsolutePath()))
+
+ val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ credBuilder.hasPrincipal shouldBe true
+ credBuilder.getPrincipal shouldBe "test-principal"
+ }
+
+ test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE that does not exist") {
+ val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> "/tmp/does-not-exist"))
+
+ intercept[FileNotFoundException] {
+ utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ }
+ }
+
+ test("Secret specified via spark.mesos.secret") {
+ val conf = new SparkConf()
+ conf.set("spark.mesos.principal", "test-principal")
+ conf.set("spark.mesos.secret", "my-secret")
+
+ val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ credBuilder.hasPrincipal shouldBe true
+ credBuilder.getPrincipal shouldBe "test-principal"
+ credBuilder.hasSecret shouldBe true
+ credBuilder.getSecret shouldBe "my-secret"
+ }
+
+ test("Principal specified via spark.mesos.secret.file") {
+ val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt");
+ sFile.deleteOnExit()
+ Files.write("my-secret".getBytes("UTF-8"), sFile);
+ val conf = new SparkConf()
+ conf.set("spark.mesos.principal", "test-principal")
+ conf.set("spark.mesos.secret.file", sFile.getAbsolutePath())
+
+ val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ credBuilder.hasPrincipal shouldBe true
+ credBuilder.getPrincipal shouldBe "test-principal"
+ credBuilder.hasSecret shouldBe true
+ credBuilder.getSecret shouldBe "my-secret"
+ }
+
+ test("Principal specified via spark.mesos.secret.file that does not exist") {
+ val conf = new SparkConf()
+ conf.set("spark.mesos.principal", "test-principal")
+ conf.set("spark.mesos.secret.file", "/tmp/does-not-exist")
+
+ intercept[FileNotFoundException] {
+ utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ }
+ }
+
+ test("Principal specified via SPARK_MESOS_SECRET") {
+ val env = Map("SPARK_MESOS_SECRET" -> "my-secret")
+ val conf = new SparkConfWithEnv(env)
+ conf.set("spark.mesos.principal", "test-principal")
+
+ val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ credBuilder.hasPrincipal shouldBe true
+ credBuilder.getPrincipal shouldBe "test-principal"
+ credBuilder.hasSecret shouldBe true
+ credBuilder.getSecret shouldBe "my-secret"
+ }
+
+ test("Principal specified via SPARK_MESOS_SECRET_FILE") {
+ val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt");
+ sFile.deleteOnExit()
+ Files.write("my-secret".getBytes("UTF-8"), sFile);
+
+ val sFilePath = sFile.getAbsolutePath()
+ val env = Map("SPARK_MESOS_SECRET_FILE" -> sFilePath)
+ val conf = new SparkConfWithEnv(env)
+ conf.set("spark.mesos.principal", "test-principal")
+
+ val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ credBuilder.hasPrincipal shouldBe true
+ credBuilder.getPrincipal shouldBe "test-principal"
+ credBuilder.hasSecret shouldBe true
+ credBuilder.getSecret shouldBe "my-secret"
+ }
+
+ test("Secret specified with no principal") {
+ val conf = new SparkConf()
+ conf.set("spark.mesos.secret", "my-secret")
+
+ intercept[SparkException] {
+ utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ }
+ }
+
+ test("Principal specification preference") {
+ val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "other-principal"))
+ conf.set("spark.mesos.principal", "test-principal")
+
+ val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ credBuilder.hasPrincipal shouldBe true
+ credBuilder.getPrincipal shouldBe "test-principal"
+ }
+
+ test("Secret specification preference") {
+ val conf = new SparkConfWithEnv(Map("SPARK_MESOS_SECRET" -> "other-secret"))
+ conf.set("spark.mesos.principal", "test-principal")
+ conf.set("spark.mesos.secret", "my-secret")
+
+ val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder())
+ credBuilder.hasPrincipal shouldBe true
+ credBuilder.getPrincipal shouldBe "test-principal"
+ credBuilder.hasSecret shouldBe true
+ credBuilder.getSecret shouldBe "my-secret"
+ }
}
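
A hypothetical client-side configuration exercising the new file-based options (paths are placeholders; each file is expected to contain just the raw value). As the preference tests above verify, a directly set property always wins over a value supplied through the environment or a file:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.mesos.principal", "analytics-framework")
  .set("spark.mesos.secret.file", "/etc/secrets/mesos-framework-secret")
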
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 2f88feb0f1fdf..3d6ee50b070a3 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -18,7 +18,7 @@
package org.apache.spark.deploy.yarn
import java.io.{File, IOException}
-import java.lang.reflect.InvocationTargetException
+import java.lang.reflect.{InvocationTargetException, Modifier}
import java.net.{Socket, URI, URL}
import java.security.PrivilegedExceptionAction
import java.util.concurrent.{TimeoutException, TimeUnit}
@@ -29,7 +29,6 @@ import scala.concurrent.duration.Duration
import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.util.StringUtils
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
@@ -41,7 +40,7 @@ import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.deploy.yarn.config._
-import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager}
+import org.apache.spark.deploy.yarn.security.AMCredentialRenewer
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.rpc._
@@ -79,42 +78,43 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
private val yarnConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf))
- private val ugi = {
- val original = UserGroupInformation.getCurrentUser()
-
- // If a principal and keytab were provided, log in to kerberos, and set up a thread to
- // renew the kerberos ticket when needed. Because the UGI API does not expose the TTL
- // of the TGT, use a configuration to define how often to check that a relogin is necessary.
- // checkTGTAndReloginFromKeytab() is a no-op if the relogin is not yet needed.
- val principal = sparkConf.get(PRINCIPAL).orNull
- val keytab = sparkConf.get(KEYTAB).orNull
- if (principal != null && keytab != null) {
- UserGroupInformation.loginUserFromKeytab(principal, keytab)
-
- val renewer = new Thread() {
- override def run(): Unit = Utils.tryLogNonFatalError {
- while (true) {
- TimeUnit.SECONDS.sleep(sparkConf.get(KERBEROS_RELOGIN_PERIOD))
- UserGroupInformation.getCurrentUser().checkTGTAndReloginFromKeytab()
- }
- }
+ private val userClassLoader = {
+ val classpath = Client.getUserClasspath(sparkConf)
+ val urls = classpath.map { entry =>
+ new URL("file:" + new File(entry.getPath()).getAbsolutePath())
+ }
+
+ if (isClusterMode) {
+ if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) {
+ new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
+ } else {
+ new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
}
- renewer.setName("am-kerberos-renewer")
- renewer.setDaemon(true)
- renewer.start()
-
- // Transfer the original user's tokens to the new user, since that's needed to connect to
- // YARN. It also copies over any delegation tokens that might have been created by the
- // client, which will then be transferred over when starting executors (until new ones
- // are created by the periodic task).
- val newUser = UserGroupInformation.getCurrentUser()
- SparkHadoopUtil.get.transferCredentials(original, newUser)
- newUser
} else {
- SparkHadoopUtil.get.createSparkUser()
+ new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
}
}
+ private val credentialRenewer: Option[AMCredentialRenewer] = sparkConf.get(KEYTAB).map { _ =>
+ new AMCredentialRenewer(sparkConf, yarnConf)
+ }
+
+ private val ugi = credentialRenewer match {
+ case Some(cr) =>
+ // Set the context class loader so that the token renewer has access to jars distributed
+ // by the user.
+ val currentLoader = Thread.currentThread().getContextClassLoader()
+ Thread.currentThread().setContextClassLoader(userClassLoader)
+ try {
+ cr.start()
+ } finally {
+ Thread.currentThread().setContextClassLoader(currentLoader)
+ }
+
+ case _ =>
+ SparkHadoopUtil.get.createSparkUser()
+ }
+
private val client = doAsUser { new YarnRMClient() }
// Default to twice the number of executors (twice the maximum number of executors if dynamic
@@ -148,23 +148,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
// A flag to check whether user has initialized spark context
@volatile private var registered = false
- private val userClassLoader = {
- val classpath = Client.getUserClasspath(sparkConf)
- val urls = classpath.map { entry =>
- new URL("file:" + new File(entry.getPath()).getAbsolutePath())
- }
-
- if (isClusterMode) {
- if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) {
- new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
- } else {
- new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
- }
- } else {
- new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
- }
- }
-
// Lock for controlling the allocator (heartbeat) thread.
private val allocatorLock = new Object()
@@ -189,8 +172,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
// In cluster mode, used to tell the AM when the user's SparkContext has been initialized.
private val sparkContextPromise = Promise[SparkContext]()
- private var credentialRenewer: AMCredentialRenewer = _
-
// Load the list of localized files set by the client. This is used when launching executors,
// and is loaded here so that these configs don't pollute the Web UI's environment page in
// cluster mode.
@@ -316,31 +297,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
}
}
- // If the credentials file config is present, we must periodically renew tokens. So create
- // a new AMDelegationTokenRenewer
- if (sparkConf.contains(CREDENTIALS_FILE_PATH)) {
- // Start a short-lived thread for AMCredentialRenewer, the only purpose is to set the
- // classloader so that main jar and secondary jars could be used by AMCredentialRenewer.
- val credentialRenewerThread = new Thread {
- setName("AMCredentialRenewerStarter")
- setContextClassLoader(userClassLoader)
-
- override def run(): Unit = {
- val credentialManager = new YARNHadoopDelegationTokenManager(
- sparkConf,
- yarnConf,
- conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf))
-
- val credentialRenewer =
- new AMCredentialRenewer(sparkConf, yarnConf, credentialManager)
- credentialRenewer.scheduleLoginFromKeytab()
- }
- }
-
- credentialRenewerThread.start()
- credentialRenewerThread.join()
- }
-
if (isClusterMode) {
runDriver()
} else {
@@ -352,7 +308,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
logError("Uncaught exception: ", e)
finish(FinalApplicationStatus.FAILED,
ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION,
- "Uncaught exception: " + e)
+ "Uncaught exception: " + StringUtils.stringifyException(e))
}
}
@@ -390,7 +346,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
synchronized {
if (!finished) {
val inShutdown = ShutdownHookManager.inShutdown()
- if (registered) {
+ if (registered || !isClusterMode) {
exitCode = code
finalStatus = status
} else {
@@ -409,53 +365,69 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
logDebug("shutting down user thread")
userClassThread.interrupt()
}
- if (!inShutdown && credentialRenewer != null) {
- credentialRenewer.stop()
- credentialRenewer = null
+ if (!inShutdown) {
+ credentialRenewer.foreach(_.stop())
}
}
}
}
private def sparkContextInitialized(sc: SparkContext) = {
- sparkContextPromise.success(sc)
+ sparkContextPromise.synchronized {
+ // Notify the runDriver function that the SparkContext is available.
+ sparkContextPromise.success(sc)
+ // Pause the user class thread so that runDriver can complete its initialization first.
+ sparkContextPromise.wait()
+ }
+ }
+
+ private def resumeDriver(): Unit = {
+ // Once initialization in runDriver has completed, the user class thread has to be resumed.
+ sparkContextPromise.synchronized {
+ sparkContextPromise.notify()
+ }
}
private def registerAM(
+ host: String,
+ port: Int,
_sparkConf: SparkConf,
- _rpcEnv: RpcEnv,
- driverRef: RpcEndpointRef,
- uiAddress: Option[String]) = {
+ uiAddress: Option[String]): Unit = {
val appId = client.getAttemptId().getApplicationId().toString()
val attemptId = client.getAttemptId().getAttemptId().toString()
val historyAddress = ApplicationMaster
.getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId)
- val driverUrl = RpcEndpointAddress(
- _sparkConf.get("spark.driver.host"),
- _sparkConf.get("spark.driver.port").toInt,
+ client.register(host, port, yarnConf, _sparkConf, uiAddress, historyAddress)
+ registered = true
+ }
+
+ private def createAllocator(driverRef: RpcEndpointRef, _sparkConf: SparkConf): Unit = {
+ val appId = client.getAttemptId().getApplicationId().toString()
+ val driverUrl = RpcEndpointAddress(driverRef.address.host, driverRef.address.port,
CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
// Before we initialize the allocator, let's log the information about how executors will
// be run up front, to avoid printing this out for every single executor being launched.
// Use placeholders for information that changes such as executor IDs.
logInfo {
- val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt
- val executorCores = sparkConf.get(EXECUTOR_CORES)
- val dummyRunner = new ExecutorRunnable(None, yarnConf, sparkConf, driverUrl, "",
+ val executorMemory = _sparkConf.get(EXECUTOR_MEMORY).toInt
+ val executorCores = _sparkConf.get(EXECUTOR_CORES)
+ val dummyRunner = new ExecutorRunnable(None, yarnConf, _sparkConf, driverUrl, "",
"", executorMemory, executorCores, appId, securityMgr, localResources)
dummyRunner.launchContextDebugInfo()
}
- allocator = client.register(driverUrl,
- driverRef,
+ allocator = client.createAllocator(
yarnConf,
_sparkConf,
- uiAddress,
- historyAddress,
+ driverUrl,
+ driverRef,
securityMgr,
localResources)
+ credentialRenewer.foreach(_.setDriverRef(driverRef))
+
// Initialize the AM endpoint *after* the allocator has been initialized. This ensures
// that when the driver sends an initial executor request (e.g. after an AM restart),
// the allocator is ready to service requests.
@@ -465,15 +437,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
reporterThread = launchReporterThread()
}
- /**
- * @return An [[RpcEndpoint]] that communicates with the driver's scheduler backend.
- */
- private def createSchedulerRef(host: String, port: String): RpcEndpointRef = {
- rpcEnv.setupEndpointRef(
- RpcAddress(host, port.toInt),
- YarnSchedulerBackend.ENDPOINT_NAME)
- }
-
private def runDriver(): Unit = {
addAmIpFilter(None)
userClassThread = startUserApplication()
@@ -487,16 +450,22 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
Duration(totalWaitTime, TimeUnit.MILLISECONDS))
if (sc != null) {
rpcEnv = sc.env.rpcEnv
- val driverRef = createSchedulerRef(
- sc.getConf.get("spark.driver.host"),
- sc.getConf.get("spark.driver.port"))
- registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl))
- registered = true
+
+ val userConf = sc.getConf
+ val host = userConf.get("spark.driver.host")
+ val port = userConf.get("spark.driver.port").toInt
+ registerAM(host, port, userConf, sc.ui.map(_.webUrl))
+
+ val driverRef = rpcEnv.setupEndpointRef(
+ RpcAddress(host, port),
+ YarnSchedulerBackend.ENDPOINT_NAME)
+ createAllocator(driverRef, userConf)
} else {
// Sanity check; should never happen in normal operation, since sc should only be null
// if the user app did not create a SparkContext.
throw new IllegalStateException("User did not initialize spark context!")
}
+ resumeDriver()
userClassThread.join()
} catch {
case e: SparkException if e.getCause().isInstanceOf[TimeoutException] =>
@@ -506,6 +475,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
finish(FinalApplicationStatus.FAILED,
ApplicationMaster.EXIT_SC_NOT_INITED,
"Timed out waiting for SparkContext.")
+ } finally {
+ resumeDriver()
}
}
@@ -514,10 +485,18 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
val amCores = sparkConf.get(AM_CORES)
rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr,
amCores, true)
- val driverRef = waitForSparkDriver()
+
+ // The client-mode AM doesn't listen for incoming connections, so report an invalid port.
+ registerAM(hostname, -1, sparkConf, sparkConf.getOption("spark.driver.appUIAddress"))
+
+ // The driver should be up and listening, so unlike cluster mode, just try to connect to it
+ // with no waiting or retrying.
+ val (driverHost, driverPort) = Utils.parseHostPort(args.userArgs(0))
+ val driverRef = rpcEnv.setupEndpointRef(
+ RpcAddress(driverHost, driverPort),
+ YarnSchedulerBackend.ENDPOINT_NAME)
addAmIpFilter(Some(driverRef))
- registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"))
- registered = true
+ createAllocator(driverRef, sparkConf)
// In client mode the actor will stop the reporter thread.
reporterThread.join()
@@ -628,40 +607,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
}
}
- private def waitForSparkDriver(): RpcEndpointRef = {
- logInfo("Waiting for Spark driver to be reachable.")
- var driverUp = false
- val hostport = args.userArgs(0)
- val (driverHost, driverPort) = Utils.parseHostPort(hostport)
-
- // Spark driver should already be up since it launched us, but we don't want to
- // wait forever, so wait 100 seconds max to match the cluster mode setting.
- val totalWaitTimeMs = sparkConf.get(AM_MAX_WAIT_TIME)
- val deadline = System.currentTimeMillis + totalWaitTimeMs
-
- while (!driverUp && !finished && System.currentTimeMillis < deadline) {
- try {
- val socket = new Socket(driverHost, driverPort)
- socket.close()
- logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
- driverUp = true
- } catch {
- case e: Exception =>
- logError("Failed to connect to driver at %s:%s, retrying ...".
- format(driverHost, driverPort))
- Thread.sleep(100L)
- }
- }
-
- if (!driverUp) {
- throw new SparkException("Failed to connect to driver!")
- }
-
- sparkConf.set("spark.driver.host", driverHost)
- sparkConf.set("spark.driver.port", driverPort.toString)
- createSchedulerRef(driverHost, driverPort.toString)
- }
-
/** Add the Yarn IP filter that is required for properly securing the UI. */
private def addAmIpFilter(driver: Option[RpcEndpointRef]) = {
val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV)
@@ -703,9 +648,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
val userThread = new Thread {
override def run() {
try {
- mainMethod.invoke(null, userArgs.toArray)
- finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
- logDebug("Done running users class")
+ if (!Modifier.isStatic(mainMethod.getModifiers)) {
+ logError(s"Could not find static main method in object ${args.userClass}")
+ finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS)
+ } else {
+ mainMethod.invoke(null, userArgs.toArray)
+ finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
+ logDebug("Done running user class")
+ }
} catch {
case e: InvocationTargetException =>
e.getCause match {
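
The sparkContextInitialized/resumeDriver pair above is a plain wait/notify handshake on the promise object: the user class thread publishes the SparkContext and parks, runDriver registers the AM and creates the allocator, then wakes it. A stripped-down sketch of the pattern with generic names, not the AM's fields:

import scala.concurrent.Promise

object HandshakeSketch {
  private val ready = Promise[String]()

  // User thread: publish the value, then park on the shared monitor.
  def initialized(value: String): Unit = ready.synchronized {
    ready.success(value)
    ready.wait()
  }

  // Coordinator thread: wake the parked initializer once setup is done.
  def resume(): Unit = ready.synchronized {
    ready.notify()
  }
}
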
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 8cd3cd9746a3a..7225ff03dc34e 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -93,11 +93,21 @@ private[spark] class Client(
private val distCacheMgr = new ClientDistributedCacheManager()
- private var loginFromKeytab = false
- private var principal: String = null
- private var keytab: String = null
- private var credentials: Credentials = null
- private var amKeytabFileName: String = null
+ private val principal = sparkConf.get(PRINCIPAL).orNull
+ private val keytab = sparkConf.get(KEYTAB).orNull
+ private val loginFromKeytab = principal != null
+ private val amKeytabFileName: String = {
+ require((principal == null) == (keytab == null),
+ "Both principal and keytab must be defined, or neither.")
+ if (loginFromKeytab) {
+ logInfo(s"Kerberos credentials: principal = $principal, keytab = $keytab")
+ // Generate a file name that can be used for the keytab file and that does not
+ // conflict with any user file.
+ new File(keytab).getName() + "-" + UUID.randomUUID().toString
+ } else {
+ null
+ }
+ }
private val launcherBackend = new LauncherBackend() {
override protected def conf: SparkConf = sparkConf
@@ -120,11 +130,6 @@ private[spark] class Client(
private val appStagingBaseDir = sparkConf.get(STAGING_DIR).map { new Path(_) }
.getOrElse(FileSystem.get(hadoopConf).getHomeDirectory())
- private val credentialManager = new YARNHadoopDelegationTokenManager(
- sparkConf,
- hadoopConf,
- conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf))
-
def reportLauncherState(state: SparkAppHandle.State): Unit = {
launcherBackend.setState(state)
}
@@ -145,9 +150,6 @@ private[spark] class Client(
var appId: ApplicationId = null
try {
launcherBackend.connect()
- // Setup the credentials before doing anything else,
- // so we have don't have issues at any point.
- setupCredentials()
yarnClient.init(hadoopConf)
yarnClient.start()
@@ -288,8 +290,26 @@ private[spark] class Client(
appContext
}
- /** Set up security tokens for launching our ApplicationMaster container. */
+ /**
+ * Set up security tokens for launching our ApplicationMaster container.
+ *
+ * This method will obtain delegation tokens from all the registered providers, and set them in
+ * the AM's launch context.
+ */
private def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = {
+ val credentials = UserGroupInformation.getCurrentUser().getCredentials()
+ val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf)
+ credentialManager.obtainDelegationTokens(hadoopConf, credentials)
+
+ // When using a proxy user, copy the delegation tokens to the user's credentials. Avoid
+ // that for regular users, since in that case the user already has access to the TGT,
+ // and adding delegation tokens could lead to expired or cancelled tokens being used
+ // later, as reported in SPARK-15754.
+ val currentUser = UserGroupInformation.getCurrentUser()
+ if (SparkHadoopUtil.get.isProxyUser(currentUser)) {
+ currentUser.addCredentials(credentials)
+ }
+
val dob = new DataOutputBuffer
credentials.writeTokenStorageToStream(dob)
amContainer.setTokens(ByteBuffer.wrap(dob.getData))
@@ -384,36 +404,6 @@ private[spark] class Client(
// and add them as local resources to the application master.
val fs = destDir.getFileSystem(hadoopConf)
- // Merge credentials obtained from registered providers
- val nearestTimeOfNextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, credentials)
-
- if (credentials != null) {
- // Add credentials to current user's UGI, so that following operations don't need to use the
- // Kerberos tgt to get delegations again in the client side.
- val currentUser = UserGroupInformation.getCurrentUser()
- if (SparkHadoopUtil.get.isProxyUser(currentUser)) {
- currentUser.addCredentials(credentials)
- }
- logDebug(SparkHadoopUtil.get.dumpTokens(credentials).mkString("\n"))
- }
-
- // If we use principal and keytab to login, also credentials can be renewed some time
- // after current time, we should pass the next renewal and updating time to credential
- // renewer and updater.
- if (loginFromKeytab && nearestTimeOfNextRenewal > System.currentTimeMillis() &&
- nearestTimeOfNextRenewal != Long.MaxValue) {
-
- // Valid renewal time is 75% of next renewal time, and the valid update time will be
- // slightly later then renewal time (80% of next renewal time). This is to make sure
- // credentials are renewed and updated before expired.
- val currTime = System.currentTimeMillis()
- val renewalTime = (nearestTimeOfNextRenewal - currTime) * 0.75 + currTime
- val updateTime = (nearestTimeOfNextRenewal - currTime) * 0.8 + currTime
-
- sparkConf.set(CREDENTIALS_RENEWAL_TIME, renewalTime.toLong)
- sparkConf.set(CREDENTIALS_UPDATE_TIME, updateTime.toLong)
- }
-
// Used to keep track of URIs added to the distributed cache. If the same URI is added
// multiple times, YARN will fail to launch containers for the app with an internal
// error.
@@ -696,7 +686,13 @@ private[spark] class Client(
}
}
- Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey =>
+ // SPARK-23630: during testing, Spark scripts filter out hadoop conf dirs so that user's
+ // environments do not interfere with tests. This allows a special env variable during
+ // tests so that custom conf dirs can be used by unit tests.
+ val confDirs = Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR") ++
+ (if (Utils.isTesting) Seq("SPARK_TEST_HADOOP_CONF_DIR") else Nil)
+
+ confDirs.foreach { envKey =>
sys.env.get(envKey).foreach { path =>
val dir = new File(path)
if (dir.isDirectory()) {
@@ -753,7 +749,7 @@ private[spark] class Client(
// Save the YARN configuration into a separate file that will be overlayed on top of the
// cluster's Hadoop conf.
- confStream.putNextEntry(new ZipEntry(SPARK_HADOOP_CONF_FILE))
+ confStream.putNextEntry(new ZipEntry(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE))
hadoopConf.writeXml(confStream)
confStream.closeEntry()
@@ -787,11 +783,6 @@ private[spark] class Client(
populateClasspath(args, hadoopConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH))
env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString
env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName()
- if (loginFromKeytab) {
- val credentialsFile = "credentials-" + UUID.randomUUID().toString
- sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString)
- logInfo(s"Credentials file set to: $credentialsFile")
- }
// Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.*
val amEnvPrefix = "spark.yarn.appMasterEnv."
@@ -901,7 +892,9 @@ private[spark] class Client(
// Include driver-specific java options if we are launching a driver
if (isClusterMode) {
sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts =>
- javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
+ javaOpts ++= Utils.splitCommandString(opts)
+ .map(Utils.substituteAppId(_, appId.toString))
+ .map(YarnSparkHadoopUtil.escapeForShell)
}
val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH),
sys.props.get("spark.driver.libraryPath")).flatten
@@ -923,7 +916,9 @@ private[spark] class Client(
s"(was '$opts'). Use spark.yarn.am.memory instead."
throw new SparkException(msg)
}
- javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
+ javaOpts ++= Utils.splitCommandString(opts)
+ .map(Utils.substituteAppId(_, appId.toString))
+ .map(YarnSparkHadoopUtil.escapeForShell)
}
sparkConf.get(AM_LIBRARY_PATH).foreach { paths =>
prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths))))
@@ -1008,25 +1003,6 @@ private[spark] class Client(
amContainer
}
- def setupCredentials(): Unit = {
- loginFromKeytab = sparkConf.contains(PRINCIPAL.key)
- if (loginFromKeytab) {
- principal = sparkConf.get(PRINCIPAL).get
- keytab = sparkConf.get(KEYTAB).orNull
-
- require(keytab != null, "Keytab must be specified when principal is specified.")
- logInfo("Attempting to login to the Kerberos" +
- s" using principal: $principal and keytab: $keytab")
- val f = new File(keytab)
- // Generate a file name that can be used for the keytab file, that does not conflict
- // with any user file.
- amKeytabFileName = f.getName + "-" + UUID.randomUUID().toString
- sparkConf.set(PRINCIPAL.key, principal)
- }
- // Defensive copy of the credentials
- credentials = new Credentials(UserGroupInformation.getCurrentUser.getCredentials)
- }
-
/**
* Report the state of an application until it has exited, either successfully or
* due to some failure, then return a pair of the yarn application state (FINISHED, FAILED,
@@ -1043,8 +1019,7 @@ private[spark] class Client(
appId: ApplicationId,
returnOnRunning: Boolean = false,
logApplicationReport: Boolean = true,
- interval: Long = sparkConf.get(REPORT_INTERVAL)):
- (YarnApplicationState, FinalApplicationStatus) = {
+ interval: Long = sparkConf.get(REPORT_INTERVAL)): YarnAppReport = {
var lastState: YarnApplicationState = null
while (true) {
Thread.sleep(interval)
@@ -1055,11 +1030,13 @@ private[spark] class Client(
case e: ApplicationNotFoundException =>
logError(s"Application $appId not found.")
cleanupStagingDir(appId)
- return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED)
+ return YarnAppReport(YarnApplicationState.KILLED, FinalApplicationStatus.KILLED, None)
case NonFatal(e) =>
- logError(s"Failed to contact YARN for application $appId.", e)
+ val msg = s"Failed to contact YARN for application $appId."
+ logError(msg, e)
// Don't necessarily clean up staging dir because status is unknown
- return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED)
+ return YarnAppReport(YarnApplicationState.FAILED, FinalApplicationStatus.FAILED,
+ Some(msg))
}
val state = report.getYarnApplicationState
@@ -1097,14 +1074,14 @@ private[spark] class Client(
}
if (state == YarnApplicationState.FINISHED ||
- state == YarnApplicationState.FAILED ||
- state == YarnApplicationState.KILLED) {
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
cleanupStagingDir(appId)
- return (state, report.getFinalApplicationStatus)
+ return createAppReport(report)
}
if (returnOnRunning && state == YarnApplicationState.RUNNING) {
- return (state, report.getFinalApplicationStatus)
+ return createAppReport(report)
}
lastState = state
@@ -1153,16 +1130,17 @@ private[spark] class Client(
throw new SparkException(s"Application $appId finished with status: $state")
}
} else {
- val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId)
- if (yarnApplicationState == YarnApplicationState.FAILED ||
- finalApplicationStatus == FinalApplicationStatus.FAILED) {
+ val YarnAppReport(appState, finalState, diags) = monitorApplication(appId)
+ if (appState == YarnApplicationState.FAILED || finalState == FinalApplicationStatus.FAILED) {
+ diags.foreach { err =>
+ logError(s"Application diagnostics message: $err")
+ }
throw new SparkException(s"Application $appId finished with failed status")
}
- if (yarnApplicationState == YarnApplicationState.KILLED ||
- finalApplicationStatus == FinalApplicationStatus.KILLED) {
+ if (appState == YarnApplicationState.KILLED || finalState == FinalApplicationStatus.KILLED) {
throw new SparkException(s"Application $appId is killed")
}
- if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) {
+ if (finalState == FinalApplicationStatus.UNDEFINED) {
throw new SparkException(s"The final status of application $appId is undefined")
}
}
@@ -1176,7 +1154,7 @@ private[spark] class Client(
val pyArchivesFile = new File(pyLibPath, "pyspark.zip")
require(pyArchivesFile.exists(),
s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.")
- val py4jFile = new File(pyLibPath, "py4j-0.10.6-src.zip")
+ val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip")
require(py4jFile.exists(),
s"$py4jFile not found; cannot run pyspark application in YARN mode.")
Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath())
@@ -1220,10 +1198,6 @@ private object Client extends Logging {
// Name of the file in the conf archive containing Spark configuration.
val SPARK_CONF_FILE = "__spark_conf__.properties"
- // Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the
- // cluster's Hadoop config.
- val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml"
-
// Subdirectory where the user's python files (not archives) will be placed.
val LOCALIZED_PYTHON_DIR = "__pyfiles__"
@@ -1505,6 +1479,12 @@ private object Client extends Logging {
uri.startsWith(s"$LOCAL_SCHEME:")
}
+ def createAppReport(report: ApplicationReport): YarnAppReport = {
+ val diags = report.getDiagnostics()
+ val diagsOpt = if (diags != null && diags.nonEmpty) Some(diags) else None
+ YarnAppReport(report.getYarnApplicationState(), report.getFinalApplicationStatus(), diagsOpt)
+ }
+
}
private[spark] class YarnClusterApplication extends SparkApplication {
@@ -1519,3 +1499,8 @@ private[spark] class YarnClusterApplication extends SparkApplication {
}
}
+
+private[spark] case class YarnAppReport(
+ appState: YarnApplicationState,
+ finalState: FinalApplicationStatus,
+ diagnostics: Option[String])
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 3f4d236571ffd..a2a18cdff65af 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -141,7 +141,8 @@ private[yarn] class ExecutorRunnable(
// Set extra Java options for the executor, if defined
sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts =>
- javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
+ val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId)
+ javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell)
}
sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p =>
prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p))))
@@ -220,12 +221,6 @@ private[yarn] class ExecutorRunnable(
val env = new HashMap[String, String]()
Client.populateClasspath(null, conf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH))
- sparkConf.getExecutorEnv.foreach { case (key, value) =>
- // This assumes each executor environment variable set here is a path
- // This is kept for backward compatibility and consistency with hadoop
- YarnSparkHadoopUtil.addPathToEnvironment(env, key, value)
- }
-
// lookup appropriate http scheme for container log urls
val yarnHttpPolicy = conf.get(
YarnConfiguration.YARN_HTTP_POLICY_KEY,
@@ -233,6 +228,20 @@ private[yarn] class ExecutorRunnable(
)
val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://"
+ System.getenv().asScala.filterKeys(_.startsWith("SPARK"))
+ .foreach { case (k, v) => env(k) = v }
+
+ sparkConf.getExecutorEnv.foreach { case (key, value) =>
+ if (key == Environment.CLASSPATH.name()) {
+ // If the key of env variable is CLASSPATH, we assume it is a path and append it.
+ // This is kept for backward compatibility and consistency with hadoop
+ YarnSparkHadoopUtil.addPathToEnvironment(env, key, value)
+ } else {
+ // For other env variables, simply overwrite the value.
+ env(key) = value
+ }
+ }
+
// Add log urls
container.foreach { c =>
sys.env.get("SPARK_USER").foreach { user =>
@@ -245,8 +254,6 @@ private[yarn] class ExecutorRunnable(
}
}
- System.getenv().asScala.filterKeys(_.startsWith("SPARK"))
- .foreach { case (k, v) => env(k) = v }
env
}
}
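
The reordered env handling above copies the launcher's SPARK_* environment first and then applies spark.executorEnv.* entries, appending path-style only for CLASSPATH and overwriting everything else. A standalone sketch of that append-vs-overwrite rule (the real separator handling lives in YarnSparkHadoopUtil.addPathToEnvironment):

import scala.collection.mutable

// Illustrative only: CLASSPATH values are appended, all other keys simply overwrite.
def applyExecutorEnv(env: mutable.Map[String, String], key: String, value: String): Unit = {
  if (key == "CLASSPATH") {
    env(key) = env.get(key).map(_ + java.io.File.pathSeparator + value).getOrElse(value)
  } else {
    env(key) = value
  }
}
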
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 506adb363aa90..ebee3d431744d 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -81,7 +81,8 @@ private[yarn] class YarnAllocator(
private val releasedContainers = Collections.newSetFromMap[ContainerId](
new ConcurrentHashMap[ContainerId, java.lang.Boolean])
- private val numExecutorsRunning = new AtomicInteger(0)
+ private val runningExecutors = Collections.newSetFromMap[String](
+ new ConcurrentHashMap[String, java.lang.Boolean]())
private val numExecutorsStarting = new AtomicInteger(0)
@@ -166,7 +167,7 @@ private[yarn] class YarnAllocator(
clock = newClock
}
- def getNumExecutorsRunning: Int = numExecutorsRunning.get()
+ def getNumExecutorsRunning: Int = runningExecutors.size()
def getNumExecutorsFailed: Int = synchronized {
val endTime = clock.getTimeMillis()
@@ -242,12 +243,11 @@ private[yarn] class YarnAllocator(
* Request that the ResourceManager release the container running the specified executor.
*/
def killExecutor(executorId: String): Unit = synchronized {
- if (executorIdToContainer.contains(executorId)) {
- val container = executorIdToContainer.get(executorId).get
- internalReleaseContainer(container)
- numExecutorsRunning.decrementAndGet()
- } else {
- logWarning(s"Attempted to kill unknown executor $executorId!")
+ executorIdToContainer.get(executorId) match {
+ case Some(container) if !releasedContainers.contains(container.getId) =>
+ internalReleaseContainer(container)
+ runningExecutors.remove(executorId)
+ case _ => logWarning(s"Attempted to kill unknown executor $executorId!")
}
}
@@ -274,7 +274,7 @@ private[yarn] class YarnAllocator(
"Launching executor count: %d. Cluster resources: %s.")
.format(
allocatedContainers.size,
- numExecutorsRunning.get,
+ runningExecutors.size,
numExecutorsStarting.get,
allocateResponse.getAvailableResources))
@@ -286,7 +286,7 @@ private[yarn] class YarnAllocator(
logDebug("Completed %d containers".format(completedContainers.size))
processCompletedContainers(completedContainers.asScala)
logDebug("Finished processing %d completed containers. Current running executor count: %d."
- .format(completedContainers.size, numExecutorsRunning.get))
+ .format(completedContainers.size, runningExecutors.size))
}
}
@@ -300,9 +300,9 @@ private[yarn] class YarnAllocator(
val pendingAllocate = getPendingAllocate
val numPendingAllocate = pendingAllocate.size
val missing = targetNumExecutors - numPendingAllocate -
- numExecutorsStarting.get - numExecutorsRunning.get
+ numExecutorsStarting.get - runningExecutors.size
logDebug(s"Updating resource requests, target: $targetNumExecutors, " +
- s"pending: $numPendingAllocate, running: ${numExecutorsRunning.get}, " +
+ s"pending: $numPendingAllocate, running: ${runningExecutors.size}, " +
s"executorsStarting: ${numExecutorsStarting.get}")
if (missing > 0) {
@@ -502,7 +502,7 @@ private[yarn] class YarnAllocator(
s"for executor with ID $executorId")
def updateInternalState(): Unit = synchronized {
- numExecutorsRunning.incrementAndGet()
+ runningExecutors.add(executorId)
numExecutorsStarting.decrementAndGet()
executorIdToContainer(executorId) = container
containerIdToExecutorId(container.getId) = executorId
@@ -513,7 +513,7 @@ private[yarn] class YarnAllocator(
allocatedContainerToHostMap.put(containerId, executorHostname)
}
- if (numExecutorsRunning.get < targetNumExecutors) {
+ if (runningExecutors.size() < targetNumExecutors) {
numExecutorsStarting.incrementAndGet()
if (launchContainers) {
launcherPool.execute(new Runnable {
@@ -554,7 +554,7 @@ private[yarn] class YarnAllocator(
} else {
logInfo(("Skip launching executorRunnable as running executors count: %d " +
"reached target executors count: %d.").format(
- numExecutorsRunning.get, targetNumExecutors))
+ runningExecutors.size, targetNumExecutors))
}
}
}
@@ -569,7 +569,11 @@ private[yarn] class YarnAllocator(
val exitReason = if (!alreadyReleased) {
// Decrement the number of executors running. The next iteration of
// the ApplicationMaster's reporting thread will take care of allocating.
- numExecutorsRunning.decrementAndGet()
+ containerIdToExecutorId.get(containerId) match {
+ case Some(executorId) => runningExecutors.remove(executorId)
+ case None => logWarning(s"Cannot find executorId for container: ${containerId.toString}")
+ }
+
logInfo("Completed container %s%s (state: %s, exit status: %s)".format(
containerId,
onHostStr,
@@ -736,7 +740,8 @@ private object YarnAllocator {
def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = {
val matcher = pattern.matcher(diagnostics)
val diag = if (matcher.find()) " " + matcher.group() + "." else ""
- ("Container killed by YARN for exceeding memory limits." + diag
- + " Consider boosting spark.yarn.executor.memoryOverhead.")
+ s"Container killed by YARN for exceeding memory limits. $diag " +
+ "Consider boosting spark.yarn.executor.memoryOverhead or " +
+ "disabling yarn.nodemanager.vmem-check-enabled because of YARN-4714."
}
}
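
Design note on the counter-to-set switch: tracking executor IDs in a concurrent set makes releases idempotent, so an executor that is killed and later reported again as a completed container is not double-counted the way an AtomicInteger decrement could be. A tiny sketch of the structure:

import java.util.Collections
import java.util.concurrent.ConcurrentHashMap

// A concurrent Set[String] backed by a ConcurrentHashMap, as used above.
val runningExecutors = Collections.newSetFromMap[String](
  new ConcurrentHashMap[String, java.lang.Boolean]())

runningExecutors.add("1")
runningExecutors.remove("1")   // removing twice is a no-op, unlike decrementing a counter
runningExecutors.size()        // the number of distinct live executor IDs
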
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index c1ae12aabb8cc..b59dcf158d87c 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -29,7 +29,6 @@ import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.util.Utils
/**
* Handles registering and unregistering the application with the YARN ResourceManager.
@@ -43,23 +42,20 @@ private[spark] class YarnRMClient extends Logging {
/**
* Registers the application master with the RM.
*
+ * @param driverHost Host name where driver is running.
+ * @param driverPort Port where driver is listening.
* @param conf The Yarn configuration.
* @param sparkConf The Spark configuration.
* @param uiAddress Address of the SparkUI.
* @param uiHistoryAddress Address of the application on the History Server.
- * @param securityMgr The security manager.
- * @param localResources Map with information about files distributed via YARN's cache.
*/
def register(
- driverUrl: String,
- driverRef: RpcEndpointRef,
+ driverHost: String,
+ driverPort: Int,
conf: YarnConfiguration,
sparkConf: SparkConf,
uiAddress: Option[String],
- uiHistoryAddress: String,
- securityMgr: SecurityManager,
- localResources: Map[String, LocalResource]
- ): YarnAllocator = {
+ uiHistoryAddress: String): Unit = {
amClient = AMRMClient.createAMRMClient()
amClient.init(conf)
amClient.start()
@@ -71,9 +67,19 @@ private[spark] class YarnRMClient extends Logging {
logInfo("Registering the ApplicationMaster")
synchronized {
- amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl)
+ amClient.registerApplicationMaster(driverHost, driverPort, trackingUrl)
registered = true
}
+ }
+
+ def createAllocator(
+ conf: YarnConfiguration,
+ sparkConf: SparkConf,
+ driverUrl: String,
+ driverRef: RpcEndpointRef,
+ securityMgr: SecurityManager,
+ localResources: Map[String, LocalResource]): YarnAllocator = {
+ require(registered, "Must register AM before creating allocator.")
new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr,
localResources, new SparkRackResolver())
}
@@ -88,6 +94,9 @@ private[spark] class YarnRMClient extends Logging {
if (registered) {
amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress)
}
+ if (amClient != null) {
+ amClient.stop()
+ }
}
/** Returns the attempt ID. */
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index f406fabd61860..7250e58b6c49a 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.util.ConverterUtils
import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.yarn.config._
-import org.apache.spark.deploy.yarn.security.CredentialUpdater
import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager
import org.apache.spark.internal.config._
import org.apache.spark.launcher.YarnCommandBuilderUtils
@@ -38,8 +37,6 @@ import org.apache.spark.util.Utils
object YarnSparkHadoopUtil {
- private var credentialUpdater: CredentialUpdater = _
-
// Additional memory overhead
// 10% was arrived at experimentally. In the interest of minimizing memory waste while covering
// the common cases. Memory overhead tends to grow with container size.
@@ -203,24 +200,29 @@ object YarnSparkHadoopUtil {
.map(new Path(_).getFileSystem(hadoopConf))
.getOrElse(FileSystem.get(hadoopConf))
- filesystemsToAccess + stagingFS
- }
-
- def startCredentialUpdater(sparkConf: SparkConf): Unit = {
- val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf)
- val credentialManager = new YARNHadoopDelegationTokenManager(
- sparkConf,
- hadoopConf,
- conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf))
- credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager)
- credentialUpdater.start()
- }
-
- def stopCredentialUpdater(): Unit = {
- if (credentialUpdater != null) {
- credentialUpdater.stop()
- credentialUpdater = null
+ // Add the list of available namenodes for all namespaces in HDFS federation.
+ // If ViewFS is enabled, this is skipped as ViewFS already handles delegation tokens for its
+ // namespaces.
+ val hadoopFilesystems = if (stagingFS.getScheme == "viewfs") {
+ Set.empty
+ } else {
+ val nameservices = hadoopConf.getTrimmedStrings("dfs.nameservices")
+ // Retrieving the filesystem for the nameservices where HA is not enabled
+ val filesystemsWithoutHA = nameservices.flatMap { ns =>
+ Option(hadoopConf.get(s"dfs.namenode.rpc-address.$ns")).map { nameNode =>
+ new Path(s"hdfs://$nameNode").getFileSystem(hadoopConf)
+ }
+ }
+ // Retrieving the filesystem for the nameservices where HA is enabled
+ val filesystemsWithHA = nameservices.flatMap { ns =>
+ Option(hadoopConf.get(s"dfs.ha.namenodes.$ns")).map { _ =>
+ new Path(s"hdfs://$ns").getFileSystem(hadoopConf)
+ }
+ }
+ (filesystemsWithoutHA ++ filesystemsWithHA).toSet
}
+
+ filesystemsToAccess ++ hadoopFilesystems + stagingFS
}
}
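hadoopFSsToAccess() now also walks dfs.nameservices, so delegation tokens are requested for every HDFS federation namespace: non-HA nameservices are resolved through their RPC address, HA nameservices through their logical hdfs://<ns> URI, and a viewfs:// default FS short-circuits the logic since ViewFS handles tokens for its mount points. A small sketch of the non-HA federation case, with hypothetical localhost addresses (Configuration, SparkConf and the usual imports assumed):

    val hadoopConf = new Configuration()
    hadoopConf.set("fs.defaultFS", "hdfs://localhost:8020")
    hadoopConf.set("dfs.nameservices", "ns1,ns2")
    hadoopConf.set("dfs.namenode.rpc-address.ns1", "localhost:8020")
    hadoopConf.set("dfs.namenode.rpc-address.ns2", "localhost:8021")
    // Returns one FileSystem per namenode plus the staging (default) filesystem.
    val fss = YarnSparkHadoopUtil.hadoopFSsToAccess(new SparkConf(), hadoopConf)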
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
index 3ba3ae5ab4401..1a99b3bd57672 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
@@ -231,16 +231,6 @@ package object config {
/* Security configuration. */
- private[spark] val CREDENTIAL_FILE_MAX_COUNT =
- ConfigBuilder("spark.yarn.credentials.file.retention.count")
- .intConf
- .createWithDefault(5)
-
- private[spark] val CREDENTIALS_FILE_MAX_RETENTION =
- ConfigBuilder("spark.yarn.credentials.file.retention.days")
- .intConf
- .createWithDefault(5)
-
private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes")
.doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " +
"fs.defaultFS does not need to be listed here.")
@@ -271,11 +261,6 @@ package object config {
/* Private configs. */
- private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file")
- .internal()
- .stringConf
- .createWithDefault(null)
-
// Internal config to propagate the location of the user's jar to the driver/executors
private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar")
.internal()
@@ -329,16 +314,6 @@ package object config {
.stringConf
.createOptional
- private[spark] val CREDENTIALS_RENEWAL_TIME = ConfigBuilder("spark.yarn.credentials.renewalTime")
- .internal()
- .timeConf(TimeUnit.MILLISECONDS)
- .createWithDefault(Long.MaxValue)
-
- private[spark] val CREDENTIALS_UPDATE_TIME = ConfigBuilder("spark.yarn.credentials.updateTime")
- .internal()
- .timeConf(TimeUnit.MILLISECONDS)
- .createWithDefault(Long.MaxValue)
-
private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period")
.timeConf(TimeUnit.SECONDS)
.createWithDefaultString("1m")
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala
index eaf2cff111a49..bc8d47dbd54c6 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala
@@ -18,221 +18,160 @@ package org.apache.spark.deploy.yarn.security
import java.security.PrivilegedExceptionAction
import java.util.concurrent.{ScheduledExecutorService, TimeUnit}
+import java.util.concurrent.atomic.AtomicReference
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.deploy.security.HadoopDelegationTokenManager
-import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens
+import org.apache.spark.ui.UIUtils
import org.apache.spark.util.ThreadUtils
/**
- * The following methods are primarily meant to make sure long-running apps like Spark
- * Streaming apps can run without interruption while accessing secured services. The
- * scheduleLoginFromKeytab method is called on the AM to get the new credentials.
- * This method wakes up a thread that logs into the KDC
- * once 75% of the renewal interval of the original credentials used for the container
- * has elapsed. It then obtains new credentials and writes them to HDFS in a
- * pre-specified location - the prefix of which is specified in the sparkConf by
- * spark.yarn.credentials.file (so the file(s) would be named c-timestamp1-1, c-timestamp2-2 etc.
- * - each update goes to a new file, with a monotonically increasing suffix), also the
- * timestamp1, timestamp2 here indicates the time of next update for CredentialUpdater.
- * After this, the credentials are renewed once 75% of the new tokens renewal interval has elapsed.
+ * A manager tasked with periodically updating delegation tokens needed by the application.
*
- * On the executor and driver (yarn client mode) side, the updateCredentialsIfRequired method is
- * called once 80% of the validity of the original credentials has elapsed. At that time the
- * executor finds the credentials file with the latest timestamp and checks if it has read those
- * credentials before (by keeping track of the suffix of the last file it read). If a new file has
- * appeared, it will read the credentials and update the currently running UGI with it. This
- * process happens again once 80% of the validity of this has expired.
+ * This manager is meant to make sure long-running apps (such as Spark Streaming apps) can run
+ * without interruption while accessing secured services. It periodically logs in to the KDC with
+ * user-provided credentials, and contacts all the configured secure services to obtain delegation
+ * tokens to be distributed to the rest of the application.
+ *
+ * This class also manages the Kerberos login, renewing the TGT when needed. Because the UGI API
+ * does not expose the TGT's TTL, a configurable period controls how often to check whether a
+ * relogin is necessary. The check runs fairly often, which is cheap because it is a no-op
+ * whenever no relogin is needed yet.
+ *
+ * New delegation tokens are created once 75% of the renewal interval of the original tokens has
+ * elapsed. The new tokens are sent to the Spark driver endpoint once it's registered with the AM.
+ * The driver is tasked with distributing the tokens to other processes that might need them.
*/
private[yarn] class AMCredentialRenewer(
sparkConf: SparkConf,
- hadoopConf: Configuration,
- credentialManager: YARNHadoopDelegationTokenManager) extends Logging {
+ hadoopConf: Configuration) extends Logging {
- private var lastCredentialsFileSuffix = 0
+ private val principal = sparkConf.get(PRINCIPAL).get
+ private val keytab = sparkConf.get(KEYTAB).get
+ private val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf)
- private val credentialRenewerThread: ScheduledExecutorService =
+ private val renewalExecutor: ScheduledExecutorService =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread")
- private val hadoopUtil = SparkHadoopUtil.get
+ private val driverRef = new AtomicReference[RpcEndpointRef]()
- private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH)
- private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION)
- private val numFilesToKeep = sparkConf.get(CREDENTIAL_FILE_MAX_COUNT)
- private val freshHadoopConf =
- hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme)
+ private val renewalTask = new Runnable() {
+ override def run(): Unit = {
+ updateTokensTask()
+ }
+ }
- @volatile private var timeOfNextRenewal: Long = sparkConf.get(CREDENTIALS_RENEWAL_TIME)
+ def setDriverRef(ref: RpcEndpointRef): Unit = {
+ driverRef.set(ref)
+ }
/**
- * Schedule a login from the keytab and principal set using the --principal and --keytab
- * arguments to spark-submit. This login happens only when the credentials of the current user
- * are about to expire. This method reads spark.yarn.principal and spark.yarn.keytab from
- * SparkConf to do the login. This method is a no-op in non-YARN mode.
+ * Start the token renewer. Upon start, the renewer will:
*
+ * - log in the configured user, and set up a task to keep that user's ticket renewed
+ * - obtain delegation tokens from all available providers
+ * - schedule a periodic task to update the tokens when needed.
+ *
+ * @return The newly logged in user.
*/
- private[spark] def scheduleLoginFromKeytab(): Unit = {
- val principal = sparkConf.get(PRINCIPAL).get
- val keytab = sparkConf.get(KEYTAB).get
-
- /**
- * Schedule re-login and creation of new credentials. If credentials have already expired, this
- * method will synchronously create new ones.
- */
- def scheduleRenewal(runnable: Runnable): Unit = {
- // Run now!
- val remainingTime = timeOfNextRenewal - System.currentTimeMillis()
- if (remainingTime <= 0) {
- logInfo("Credentials have expired, creating new ones now.")
- runnable.run()
- } else {
- logInfo(s"Scheduling login from keytab in $remainingTime millis.")
- credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS)
+ def start(): UserGroupInformation = {
+ val originalCreds = UserGroupInformation.getCurrentUser().getCredentials()
+ val ugi = doLogin()
+
+ val tgtRenewalTask = new Runnable() {
+ override def run(): Unit = {
+ ugi.checkTGTAndReloginFromKeytab()
}
}
+ val tgtRenewalPeriod = sparkConf.get(KERBEROS_RELOGIN_PERIOD)
+ renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod,
+ TimeUnit.SECONDS)
- // This thread periodically runs on the AM to update the credentials on HDFS.
- val credentialRenewerRunnable =
- new Runnable {
- override def run(): Unit = {
- try {
- writeNewCredentialsToHDFS(principal, keytab)
- cleanupOldFiles()
- } catch {
- case e: Exception =>
- // Log the error and try to write new tokens back in an hour
- logWarning("Failed to write out new credentials to HDFS, will try again in an " +
- "hour! If this happens too often tasks will fail.", e)
- credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS)
- return
- }
- scheduleRenewal(this)
- }
- }
- // Schedule update of credentials. This handles the case of updating the credentials right now
- // as well, since the renewal interval will be 0, and the thread will get scheduled
- // immediately.
- scheduleRenewal(credentialRenewerRunnable)
+ val creds = obtainTokensAndScheduleRenewal(ugi)
+ ugi.addCredentials(creds)
+
+ // Transfer the original user's tokens to the new user, since that's needed to connect to
+ // YARN. Explicitly avoid overwriting tokens that already exist in the current user's
+ // credentials, since those were freshly obtained above (see SPARK-23361).
+ val existing = ugi.getCredentials()
+ existing.mergeAll(originalCreds)
+ ugi.addCredentials(existing)
+
+ ugi
+ }
+
+ def stop(): Unit = {
+ renewalExecutor.shutdown()
+ }
+
+ private def scheduleRenewal(delay: Long): Unit = {
+ val _delay = math.max(0, delay)
+ logInfo(s"Scheduling login from keytab in ${UIUtils.formatDuration(delay)}.")
+ renewalExecutor.schedule(renewalTask, _delay, TimeUnit.MILLISECONDS)
}
- // Keeps only files that are newer than daysToKeepFiles days, and deletes everything else. At
- // least numFilesToKeep files are kept for safety
- private def cleanupOldFiles(): Unit = {
- import scala.concurrent.duration._
+ /**
+ * Periodic task to login to the KDC and create new delegation tokens. Re-schedules itself
+ * to fetch the next set of tokens when needed.
+ */
+ private def updateTokensTask(): Unit = {
try {
- val remoteFs = FileSystem.get(freshHadoopConf)
- val credentialsPath = new Path(credentialsFile)
- val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles.days).toMillis
- hadoopUtil.listFilesSorted(
- remoteFs, credentialsPath.getParent,
- credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
- .dropRight(numFilesToKeep)
- .takeWhile(_.getModificationTime < thresholdTime)
- .foreach(x => remoteFs.delete(x.getPath, true))
+ val freshUGI = doLogin()
+ val creds = obtainTokensAndScheduleRenewal(freshUGI)
+ val tokens = SparkHadoopUtil.get.serialize(creds)
+
+ val driver = driverRef.get()
+ if (driver != null) {
+ logInfo("Updating delegation tokens.")
+ driver.send(UpdateDelegationTokens(tokens))
+ } else {
+ // This shouldn't really happen, since the driver should register way before tokens expire
+ // (or the AM should time out the application).
+ logWarning("Delegation tokens close to expiration but no driver has registered yet.")
+ SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf)
+ }
} catch {
- // Such errors are not fatal, so don't throw. Make sure they are logged though
case e: Exception =>
- logWarning("Error while attempting to cleanup old credentials. If you are seeing many " +
- "such warnings there may be an issue with your HDFS cluster.", e)
+ val delay = TimeUnit.SECONDS.toMillis(sparkConf.get(CREDENTIALS_RENEWAL_RETRY_WAIT))
+ logWarning(s"Failed to update tokens, will try again in ${UIUtils.formatDuration(delay)}!" +
+ " If this happens too often tasks will fail.", e)
+ scheduleRenewal(delay)
}
}
- private def writeNewCredentialsToHDFS(principal: String, keytab: String): Unit = {
- // Keytab is copied by YARN to the working directory of the AM, so full path is
- // not needed.
-
- // HACK:
- // HDFS will not issue new delegation tokens, if the Credentials object
- // passed in already has tokens for that FS even if the tokens are expired (it really only
- // checks if there are tokens for the service, and not if they are valid). So the only real
- // way to get new tokens is to make sure a different Credentials object is used each time to
- // get new tokens and then the new tokens are copied over the current user's Credentials.
- // So:
- // - we login as a different user and get the UGI
- // - use that UGI to get the tokens (see doAs block below)
- // - copy the tokens over to the current user's credentials (this will overwrite the tokens
- // in the current user's Credentials object for this FS).
- // The login to KDC happens each time new tokens are required, but this is rare enough to not
- // have to worry about (like once every day or so). This makes this code clearer than having
- // to login and then relogin every time (the HDFS API may not relogin since we don't use this
- // UGI directly for HDFS communication.
- logInfo(s"Attempting to login to KDC using principal: $principal")
- val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab)
- logInfo("Successfully logged into KDC.")
- val tempCreds = keytabLoggedInUGI.getCredentials
- val credentialsPath = new Path(credentialsFile)
- val dst = credentialsPath.getParent
- var nearestNextRenewalTime = Long.MaxValue
- keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] {
- // Get a copy of the credentials
- override def run(): Void = {
- nearestNextRenewalTime = credentialManager.obtainDelegationTokens(
- freshHadoopConf,
- tempCreds)
- null
+ /**
+ * Obtain new delegation tokens from the available providers. Schedules a new task to fetch
+ * new tokens before the new set expires.
+ *
+ * @return Credentials containing the new tokens.
+ */
+ private def obtainTokensAndScheduleRenewal(ugi: UserGroupInformation): Credentials = {
+ ugi.doAs(new PrivilegedExceptionAction[Credentials]() {
+ override def run(): Credentials = {
+ val creds = new Credentials()
+ val nextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, creds)
+
+ val timeToWait = SparkHadoopUtil.nextCredentialRenewalTime(nextRenewal, sparkConf) -
+ System.currentTimeMillis()
+ scheduleRenewal(timeToWait)
+ creds
}
})
-
- val currTime = System.currentTimeMillis()
- val timeOfNextUpdate = if (nearestNextRenewalTime <= currTime) {
- // If next renewal time is earlier than current time, we set next renewal time to current
- // time, this will trigger next renewal immediately. Also set next update time to current
- // time. There still has a gap between token renewal and update will potentially introduce
- // issue.
- logWarning(s"Next credential renewal time ($nearestNextRenewalTime) is earlier than " +
- s"current time ($currTime), which is unexpected, please check your credential renewal " +
- "related configurations in the target services.")
- timeOfNextRenewal = currTime
- currTime
- } else {
- // Next valid renewal time is about 75% of credential renewal time, and update time is
- // slightly later than valid renewal time (80% of renewal time).
- timeOfNextRenewal =
- SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.75)
- SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.8)
- }
-
- // Add the temp credentials back to the original ones.
- UserGroupInformation.getCurrentUser.addCredentials(tempCreds)
- val remoteFs = FileSystem.get(freshHadoopConf)
- // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM
- // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file
- // and update the lastCredentialsFileSuffix.
- if (lastCredentialsFileSuffix == 0) {
- hadoopUtil.listFilesSorted(
- remoteFs, credentialsPath.getParent,
- credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
- .lastOption.foreach { status =>
- lastCredentialsFileSuffix = hadoopUtil.getSuffixForCredentialsPath(status.getPath)
- }
- }
- val nextSuffix = lastCredentialsFileSuffix + 1
-
- val tokenPathStr =
- credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM +
- timeOfNextUpdate.toLong.toString + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM +
- nextSuffix
- val tokenPath = new Path(tokenPathStr)
- val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
-
- logInfo("Writing out delegation tokens to " + tempTokenPath.toString)
- val credentials = UserGroupInformation.getCurrentUser.getCredentials
- credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf)
- logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr")
- remoteFs.rename(tempTokenPath, tokenPath)
- logInfo("Delegation token file rename complete.")
- lastCredentialsFileSuffix = nextSuffix
}
- def stop(): Unit = {
- credentialRenewerThread.shutdown()
+ private def doLogin(): UserGroupInformation = {
+ logInfo(s"Attempting to login to KDC using principal: $principal")
+ val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab)
+ logInfo("Successfully logged into KDC.")
+ ugi
}
+
}
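The rewritten renewer no longer writes token files to HDFS; it keeps credentials in memory and pushes updates to the driver over RPC. A rough sketch of how the ApplicationMaster is expected to drive it (driverRef and runApplicationMaster() are placeholders for things the AM already has; imports elided):

    val renewer = new AMCredentialRenewer(sparkConf, hadoopConf)
    val ugi = renewer.start()          // keytab login; schedules TGT checks and token renewal
    // Once the driver endpoint registers, new tokens are pushed as UpdateDelegationTokens:
    renewer.setDriverRef(driverRef)
    // The rest of the AM runs as the freshly logged-in user:
    ugi.doAs(new PrivilegedExceptionAction[Unit] {
      override def run(): Unit = runApplicationMaster()
    })
    renewer.stop()                     // on shutdown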
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala
deleted file mode 100644
index fe173dffc22a8..0000000000000
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.deploy.yarn.security
-
-import java.util.concurrent.{Executors, TimeUnit}
-
-import scala.util.control.NonFatal
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.security.{Credentials, UserGroupInformation}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.deploy.yarn.config._
-import org.apache.spark.internal.Logging
-import org.apache.spark.util.{ThreadUtils, Utils}
-
-private[spark] class CredentialUpdater(
- sparkConf: SparkConf,
- hadoopConf: Configuration,
- credentialManager: YARNHadoopDelegationTokenManager) extends Logging {
-
- @volatile private var lastCredentialsFileSuffix = 0
-
- private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH)
- private val freshHadoopConf =
- SparkHadoopUtil.get.getConfBypassingFSCache(
- hadoopConf, new Path(credentialsFile).toUri.getScheme)
-
- private val credentialUpdater =
- Executors.newSingleThreadScheduledExecutor(
- ThreadUtils.namedThreadFactory("Credential Refresh Thread"))
-
- // This thread wakes up and picks up new credentials from HDFS, if any.
- private val credentialUpdaterRunnable =
- new Runnable {
- override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired())
- }
-
- /** Start the credential updater task */
- def start(): Unit = {
- val startTime = sparkConf.get(CREDENTIALS_UPDATE_TIME)
- val remainingTime = startTime - System.currentTimeMillis()
- if (remainingTime <= 0) {
- credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES)
- } else {
- logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.")
- credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS)
- }
- }
-
- private def updateCredentialsIfRequired(): Unit = {
- val timeToNextUpdate = try {
- val credentialsFilePath = new Path(credentialsFile)
- val remoteFs = FileSystem.get(freshHadoopConf)
- SparkHadoopUtil.get.listFilesSorted(
- remoteFs, credentialsFilePath.getParent,
- credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION)
- .lastOption.map { credentialsStatus =>
- val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath)
- if (suffix > lastCredentialsFileSuffix) {
- logInfo("Reading new credentials from " + credentialsStatus.getPath)
- val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath)
- lastCredentialsFileSuffix = suffix
- UserGroupInformation.getCurrentUser.addCredentials(newCredentials)
- logInfo("Credentials updated from credentials file.")
-
- val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath)
- - System.currentTimeMillis())
- if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime
- } else {
- // If current credential file is older than expected, sleep 1 hour and check again.
- TimeUnit.HOURS.toMillis(1)
- }
- }.getOrElse {
- // Wait for 1 minute to check again if there's no credential file currently
- TimeUnit.MINUTES.toMillis(1)
- }
- } catch {
- // Since the file may get deleted while we are reading it, catch the Exception and come
- // back in an hour to try again
- case NonFatal(e) =>
- logWarning("Error while trying to update credentials, will try again in 1 hour", e)
- TimeUnit.HOURS.toMillis(1)
- }
-
- logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.")
- credentialUpdater.schedule(
- credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS)
- }
-
- private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = {
- val stream = remoteFs.open(tokenPath)
- try {
- val newCredentials = new Credentials()
- newCredentials.readTokenStorageStream(stream)
- newCredentials
- } finally {
- stream.close()
- }
- }
-
- private def getTimeOfNextUpdateFromFileName(credentialsPath: Path): Long = {
- val name = credentialsPath.getName
- val index = name.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM)
- val slice = name.substring(0, index)
- val last2index = slice.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM)
- name.substring(last2index + 1, index).toLong
- }
-
- def stop(): Unit = {
- credentialUpdater.shutdown()
- }
-
-}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala
index 163cfb4eb8624..26a2e5d730218 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala
@@ -22,11 +22,11 @@ import java.util.ServiceLoader
import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.security.Credentials
import org.apache.spark.SparkConf
import org.apache.spark.deploy.security.HadoopDelegationTokenManager
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
@@ -37,14 +37,17 @@ import org.apache.spark.util.Utils
*/
private[yarn] class YARNHadoopDelegationTokenManager(
sparkConf: SparkConf,
- hadoopConf: Configuration,
- fileSystems: Configuration => Set[FileSystem]) extends Logging {
+ hadoopConf: Configuration) extends Logging {
- private val delegationTokenManager =
- new HadoopDelegationTokenManager(sparkConf, hadoopConf, fileSystems)
+ private val delegationTokenManager = new HadoopDelegationTokenManager(sparkConf, hadoopConf,
+ conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf))
// public for testing
val credentialProviders = getCredentialProviders
+ if (credentialProviders.nonEmpty) {
+ logDebug("Using the following YARN-specific credential providers: " +
+ s"${credentialProviders.keys.mkString(", ")}.")
+ }
/**
* Writes delegation tokens to creds. Delegation tokens are fetched from all registered
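Since the manager now derives the filesystems to access itself (via YarnSparkHadoopUtil.hadoopFSsToAccess), the third constructor argument goes away. A minimal usage sketch, mirroring how AMCredentialRenewer calls it above:

    val manager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf)
    val creds = new Credentials()
    // Fills `creds` and returns the earliest time (in ms) at which a token needs renewal.
    val nextRenewalTime = manager.obtainDelegationTokens(hadoopConf, creds)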
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 0c6206eebe41d..f1a8df00f9c5b 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.yarn.api.records.YarnApplicationState
import org.apache.spark.{SparkContext, SparkException}
-import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil}
+import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnAppReport}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.launcher.SparkAppHandle
@@ -62,12 +62,6 @@ private[spark] class YarnClientSchedulerBackend(
super.start()
waitForApplication()
- // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver
- // reads the credentials from HDFS, just like the executors and updates its own credentials
- // cache.
- if (conf.contains("spark.yarn.credentials.file")) {
- YarnSparkHadoopUtil.startCredentialUpdater(conf)
- }
monitorThread = asyncMonitorApplication()
monitorThread.start()
}
@@ -81,13 +75,23 @@ private[spark] class YarnClientSchedulerBackend(
val monitorInterval = conf.get(CLIENT_LAUNCH_MONITOR_INTERVAL)
assert(client != null && appId.isDefined, "Application has not been submitted yet!")
- val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true,
- interval = monitorInterval) // blocking
+ val YarnAppReport(state, _, diags) = client.monitorApplication(appId.get,
+ returnOnRunning = true, interval = monitorInterval)
if (state == YarnApplicationState.FINISHED ||
- state == YarnApplicationState.FAILED ||
- state == YarnApplicationState.KILLED) {
- throw new SparkException("Yarn application has already ended! " +
- "It might have been killed or unable to launch application master.")
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ val genericMessage = "The YARN application has already ended! " +
+ "It might have been killed or the Application Master may have failed to start. " +
+ "Check the YARN application logs for more details."
+ val exceptionMsg = diags match {
+ case Some(msg) =>
+ logError(genericMessage)
+ msg
+
+ case None =>
+ genericMessage
+ }
+ throw new SparkException(exceptionMsg)
}
if (state == YarnApplicationState.RUNNING) {
logInfo(s"Application ${appId.get} has started running.")
@@ -106,8 +110,13 @@ private[spark] class YarnClientSchedulerBackend(
override def run() {
try {
- val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false)
- logError(s"Yarn application has already exited with state $state!")
+ val YarnAppReport(_, state, diags) =
+ client.monitorApplication(appId.get, logApplicationReport = true)
+ logError(s"YARN application has exited unexpectedly with state $state! " +
+ "Check the YARN application logs for more details.")
+ diags.foreach { err =>
+ logError(s"Diagnostics message: $err")
+ }
allowInterrupt = false
sc.stop()
} catch {
@@ -130,7 +139,7 @@ private[spark] class YarnClientSchedulerBackend(
private def asyncMonitorApplication(): MonitorThread = {
assert(client != null && appId.isDefined, "Application has not been submitted yet!")
val t = new MonitorThread
- t.setName("Yarn application state monitor")
+ t.setName("YARN application state monitor")
t.setDaemon(true)
t
}
@@ -153,7 +162,6 @@ private[spark] class YarnClientSchedulerBackend(
client.reportLauncherState(SparkAppHandle.State.FINISHED)
super.stop()
- YarnSparkHadoopUtil.stopCredentialUpdater()
client.stop()
logInfo("Stopped")
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index bb615c36cd97f..63bea3e7a5003 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -24,9 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}
import scala.util.control.NonFatal
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
import org.apache.spark.SparkContext
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.rpc._
import org.apache.spark.scheduler._
@@ -70,6 +72,7 @@ private[spark] abstract class YarnSchedulerBackend(
/** Scheduler extension services. */
private val services: SchedulerExtensionServices = new SchedulerExtensionServices()
+
/**
* Bind to YARN. This *must* be done before calling [[start()]].
*
@@ -263,8 +266,13 @@ private[spark] abstract class YarnSchedulerBackend(
logWarning(s"Requesting driver to remove executor $executorId for reason $reason")
driverEndpoint.send(r)
}
- }
+ case u @ UpdateDelegationTokens(tokens) =>
+ // Add the tokens to the current user and send a message to the scheduler so that it
+ // notifies all registered executors of the new tokens.
+ SparkHadoopUtil.get.addDelegationTokens(tokens, sc.conf)
+ driverEndpoint.send(u)
+ }
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case r: RequestExecutors =>
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index cb1e3c5268510..525abb6f2b350 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -251,11 +251,55 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0)
}
handler.updateResourceRequests()
- handler.processCompletedContainers(statuses.toSeq)
+ handler.processCompletedContainers(statuses)
handler.getNumExecutorsRunning should be (0)
handler.getPendingAllocate.size should be (1)
}
+ test("kill same executor multiple times") {
+ val handler = createAllocator(2)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (2)
+
+ val container1 = createContainer("host1")
+ val container2 = createContainer("host2")
+ handler.handleAllocatedContainers(Array(container1, container2))
+ handler.getNumExecutorsRunning should be (2)
+ handler.getPendingAllocate.size should be (0)
+
+ val executorToKill = handler.executorIdToContainer.keys.head
+ handler.killExecutor(executorToKill)
+ handler.getNumExecutorsRunning should be (1)
+ handler.killExecutor(executorToKill)
+ handler.killExecutor(executorToKill)
+ handler.killExecutor(executorToKill)
+ handler.getNumExecutorsRunning should be (1)
+ handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty, Set.empty)
+ handler.updateResourceRequests()
+ handler.getPendingAllocate.size should be (1)
+ }
+
+ test("process same completed container multiple times") {
+ val handler = createAllocator(2)
+ handler.updateResourceRequests()
+ handler.getNumExecutorsRunning should be (0)
+ handler.getPendingAllocate.size should be (2)
+
+ val container1 = createContainer("host1")
+ val container2 = createContainer("host2")
+ handler.handleAllocatedContainers(Array(container1, container2))
+ handler.getNumExecutorsRunning should be (2)
+ handler.getPendingAllocate.size should be (0)
+
+ val statuses = Seq(container1, container1, container2).map { c =>
+ ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0)
+ }
+ handler.processCompletedContainers(statuses)
+ handler.getNumExecutorsRunning should be (0)
+
+ }
+
test("lost executor removed from backend") {
val handler = createAllocator(4)
handler.updateResourceRequests()
@@ -272,7 +316,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1)
}
handler.updateResourceRequests()
- handler.processCompletedContainers(statuses.toSeq)
+ handler.processCompletedContainers(statuses)
handler.updateResourceRequests()
handler.getNumExecutorsRunning should be (0)
handler.getPendingAllocate.size should be (2)
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 5003326b440bf..3b78b88de778d 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -114,12 +114,25 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
))
}
- test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414)") {
+ test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414, SPARK-23630)") {
+ // Create a custom hadoop config file, to make sure its contents are propagated to the driver.
+ val customConf = Utils.createTempDir()
+ val coreSite = """
+ |
+ |
+ | spark.test.key
+ | testvalue
+ |
+ |
+ |""".stripMargin
+ Files.write(coreSite, new File(customConf, "core-site.xml"), StandardCharsets.UTF_8)
+
val result = File.createTempFile("result", null, tempDir)
val finalState = runSpark(false,
mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass),
- appArgs = Seq("key=value", result.getAbsolutePath()),
- extraConf = Map("spark.hadoop.key" -> "value"))
+ appArgs = Seq("key=value", "spark.test.key=testvalue", result.getAbsolutePath()),
+ extraConf = Map("spark.hadoop.key" -> "value"),
+ extraEnv = Map("SPARK_TEST_HADOOP_CONF_DIR" -> customConf.getAbsolutePath()))
checkResult(finalState, result)
}
@@ -212,6 +225,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
finalState should be (SparkAppHandle.State.FAILED)
}
+ test("executor env overwrite AM env in client mode") {
+ testExecutorEnv(true)
+ }
+
+ test("executor env overwrite AM env in cluster mode") {
+ testExecutorEnv(false)
+ }
+
private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = {
val result = File.createTempFile("result", null, tempDir)
val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
@@ -244,22 +265,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
// needed locations.
val sparkHome = sys.props("spark.test.home")
val pythonPath = Seq(
- s"$sparkHome/python/lib/py4j-0.10.6-src.zip",
+ s"$sparkHome/python/lib/py4j-0.10.7-src.zip",
s"$sparkHome/python")
val extraEnvVars = Map(
"PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator),
"PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv
- val moduleDir =
- if (clientMode) {
- // In client-mode, .py files added with --py-files are not visible in the driver.
- // This is something that the launcher library would have to handle.
- tempDir
- } else {
- val subdir = new File(tempDir, "pyModules")
- subdir.mkdir()
- subdir
- }
+ val moduleDir = {
+ val subdir = new File(tempDir, "pyModules")
+ subdir.mkdir()
+ subdir
+ }
val pyModule = new File(moduleDir, "mod1.py")
Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8)
@@ -292,6 +308,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
checkResult(finalState, executorResult, "OVERRIDDEN")
}
+ private def testExecutorEnv(clientMode: Boolean): Unit = {
+ val result = File.createTempFile("result", null, tempDir)
+ val finalState = runSpark(clientMode, mainClassName(ExecutorEnvTestApp.getClass),
+ appArgs = Seq(result.getAbsolutePath),
+ extraConf = Map(
+ "spark.yarn.appMasterEnv.TEST_ENV" -> "am_val",
+ "spark.executorEnv.TEST_ENV" -> "executor_val"
+ )
+ )
+ checkResult(finalState, result, "true")
+ }
}
private[spark] class SaveExecutorInfo extends SparkListener {
@@ -319,13 +346,13 @@ private object YarnClusterDriverWithFailure extends Logging with Matchers {
private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matchers {
def main(args: Array[String]): Unit = {
- if (args.length != 2) {
+ if (args.length < 2) {
// scalastyle:off println
System.err.println(
s"""
|Invalid command line: ${args.mkString(" ")}
|
- |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value] [result file]
+ |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value]+ [result file]
""".stripMargin)
// scalastyle:on println
System.exit(1)
@@ -335,11 +362,16 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc
.set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
.setAppName("yarn test using SparkHadoopUtil's conf"))
- val kv = args(0).split("=")
- val status = new File(args(1))
+ val kvs = args.take(args.length - 1).map { kv =>
+ val parsed = kv.split("=")
+ (parsed(0), parsed(1))
+ }
+ val status = new File(args.last)
var result = "failure"
try {
- SparkHadoopUtil.get.conf.get(kv(0)) should be (kv(1))
+ kvs.foreach { case (k, v) =>
+ SparkHadoopUtil.get.conf.get(k) should be (v)
+ }
result = "success"
} finally {
Files.write(result, status, StandardCharsets.UTF_8)
@@ -508,3 +540,20 @@ private object SparkContextTimeoutApp {
}
}
+
+private object ExecutorEnvTestApp {
+
+ def main(args: Array[String]): Unit = {
+ val status = args(0)
+ val sparkConf = new SparkConf()
+ val sc = new SparkContext(sparkConf)
+ val executorEnvs = sc.parallelize(Seq(1)).flatMap { _ => sys.env }.collect().toMap
+ val result = sparkConf.getExecutorEnv.forall { case (k, v) =>
+ executorEnvs.get(k).contains(v)
+ }
+
+ Files.write(result.toString, new File(status), StandardCharsets.UTF_8)
+ sc.stop()
+ }
+
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
index f21353aa007c8..61c0c43f7c04f 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -21,7 +21,8 @@ import java.io.{File, IOException}
import java.nio.charset.StandardCharsets
import com.google.common.io.{ByteStreams, Files}
-import org.apache.hadoop.io.Text
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.yarn.api.records.ApplicationAccessType
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.scalatest.Matchers
@@ -141,4 +142,66 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging
}
+ test("SPARK-24149: retrieve all namenodes from HDFS") {
+ val sparkConf = new SparkConf()
+ val basicFederationConf = new Configuration()
+ basicFederationConf.set("fs.defaultFS", "hdfs://localhost:8020")
+ basicFederationConf.set("dfs.nameservices", "ns1,ns2")
+ basicFederationConf.set("dfs.namenode.rpc-address.ns1", "localhost:8020")
+ basicFederationConf.set("dfs.namenode.rpc-address.ns2", "localhost:8021")
+ val basicFederationExpected = Set(
+ new Path("hdfs://localhost:8020").getFileSystem(basicFederationConf),
+ new Path("hdfs://localhost:8021").getFileSystem(basicFederationConf))
+ val basicFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(
+ sparkConf, basicFederationConf)
+ basicFederationResult should be (basicFederationExpected)
+
+ // when ViewFS is enabled, it handles the namespaces itself, so we don't need to take care of them
+ val viewFsConf = new Configuration()
+ viewFsConf.addResource(basicFederationConf)
+ viewFsConf.set("fs.defaultFS", "viewfs://clusterX/")
+ viewFsConf.set("fs.viewfs.mounttable.clusterX.link./home", "hdfs://localhost:8020/")
+ val viewFsExpected = Set(new Path("viewfs://clusterX/").getFileSystem(viewFsConf))
+ YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, viewFsConf) should be (viewFsExpected)
+
+ // invalid config should not throw NullPointerException
+ val invalidFederationConf = new Configuration()
+ invalidFederationConf.addResource(basicFederationConf)
+ invalidFederationConf.unset("dfs.namenode.rpc-address.ns2")
+ val invalidFederationExpected = Set(
+ new Path("hdfs://localhost:8020").getFileSystem(invalidFederationConf))
+ val invalidFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(
+ sparkConf, invalidFederationConf)
+ invalidFederationResult should be (invalidFederationExpected)
+
+ // no namespaces defined, i.e. the old case
+ val noFederationConf = new Configuration()
+ noFederationConf.set("fs.defaultFS", "hdfs://localhost:8020")
+ val noFederationExpected = Set(
+ new Path("hdfs://localhost:8020").getFileSystem(noFederationConf))
+ val noFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, noFederationConf)
+ noFederationResult should be (noFederationExpected)
+
+ // federation and HA enabled
+ val federationAndHAConf = new Configuration()
+ federationAndHAConf.set("fs.defaultFS", "hdfs://clusterXHA")
+ federationAndHAConf.set("dfs.nameservices", "clusterXHA,clusterYHA")
+ federationAndHAConf.set("dfs.ha.namenodes.clusterXHA", "x-nn1,x-nn2")
+ federationAndHAConf.set("dfs.ha.namenodes.clusterYHA", "y-nn1,y-nn2")
+ federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn1", "localhost:8020")
+ federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn2", "localhost:8021")
+ federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn1", "localhost:8022")
+ federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn2", "localhost:8023")
+ federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterXHA",
+ "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider")
+ federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterYHA",
+ "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider")
+
+ val federationAndHAExpected = Set(
+ new Path("hdfs://clusterXHA").getFileSystem(federationAndHAConf),
+ new Path("hdfs://clusterYHA").getFileSystem(federationAndHAConf))
+ val federationAndHAResult = YarnSparkHadoopUtil.hadoopFSsToAccess(
+ sparkConf, federationAndHAConf)
+ federationAndHAResult should be (federationAndHAExpected)
+ }
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala
index 3c7cdc0f1dab8..9fa749b14c98c 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala
@@ -22,7 +22,6 @@ import org.apache.hadoop.security.Credentials
import org.scalatest.Matchers
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil
class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers {
private var credentialManager: YARNHadoopDelegationTokenManager = null
@@ -36,11 +35,7 @@ class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers
}
test("Correctly loads credential providers") {
- credentialManager = new YARNHadoopDelegationTokenManager(
- sparkConf,
- hadoopConf,
- conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf))
-
+ credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf)
credentialManager.credentialProviders.get("yarn-test") should not be (None)
}
}
diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh
index bac154e10ae62..bf3da18c3706e 100755
--- a/sbin/spark-config.sh
+++ b/sbin/spark-config.sh
@@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}"
# Add the PySpark classes to the PYTHONPATH:
if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then
export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}"
- export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:${PYTHONPATH}"
+ export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}"
export PYSPARK_PYTHONPATH_SET=1
fi
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index e2fa5754afaee..e65e3aafe5b5b 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -229,7 +229,7 @@ This file is divided into 3 sections:
<parameters><parameter name="regex">extractOpt</parameter></parameters>
- <customMessage>Use Utils.jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter
+ <customMessage>Use jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter
is slower. </customMessage>
diff --git a/sql/README.md b/sql/README.md
index fe1d352050c09..70cc7c637b58d 100644
--- a/sql/README.md
+++ b/sql/README.md
@@ -6,7 +6,7 @@ This module provides support for executing relational queries expressed in eithe
Spark SQL is broken up into four subprojects:
- Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions.
- Execution (sql/core) - A query planner / execution engine for translating Catalyst's logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files.
- - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs.
+ - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allow users to run queries that include Hive UDFs, UDAFs, and UDTFs.
- HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server.
Running `sql/create-docs.sh` generates SQL documentation for built-in functions under `sql/site`.
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 5fa75fe348e68..3fe00eefde7d8 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -398,7 +398,7 @@ hintStatement
;
fromClause
- : FROM relation (',' relation)* lateralView*
+ : FROM relation (',' relation)* lateralView* pivotClause?
;
aggregation
@@ -413,6 +413,10 @@ groupingSet
| expression
;
+pivotClause
+ : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')'
+ ;
+
lateralView
: LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)?
;
@@ -588,6 +592,7 @@ primaryExpression
| identifier #columnReference
| base=primaryExpression '.' fieldName=identifier #dereference
| '(' expression ')' #parenthesizedExpression
+ | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract
;
constant
@@ -725,7 +730,7 @@ nonReserved
| ADD
| OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER
| MAP | ARRAY | STRUCT
- | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER
+ | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER
| DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED
| EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS
| GROUPING | CUBE | ROLLUP
@@ -735,6 +740,7 @@ nonReserved
| VIEW | REPLACE
| IF
| POSITION
+ | EXTRACT
| NO | DATA
| START | TRANSACTION | COMMIT | ROLLBACK | IGNORE
| SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION
@@ -745,7 +751,7 @@ nonReserved
| REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE
| ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH
| ASC | DESC | LIMIT | RENAME | SETS
- | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE
+ | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE
| DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE
| NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE
| AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN
@@ -760,6 +766,7 @@ FROM: 'FROM';
ADD: 'ADD';
AS: 'AS';
ALL: 'ALL';
+ANY: 'ANY';
DISTINCT: 'DISTINCT';
WHERE: 'WHERE';
GROUP: 'GROUP';
@@ -805,6 +812,7 @@ RIGHT: 'RIGHT';
FULL: 'FULL';
NATURAL: 'NATURAL';
ON: 'ON';
+PIVOT: 'PIVOT';
LATERAL: 'LATERAL';
WINDOW: 'WINDOW';
OVER: 'OVER';
@@ -872,6 +880,7 @@ TRAILING: 'TRAILING';
IF: 'IF';
POSITION: 'POSITION';
+EXTRACT: 'EXTRACT';
EQ : '=' | '==';
NSEQ: '<=>';
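The grammar additions introduce a PIVOT clause at the end of the FROM clause and an EXTRACT(field FROM expr) primary expression. A sketch of the syntax the new rules are meant to accept, assuming a SparkSession named spark and made-up table and column names:

    spark.sql(
      """SELECT * FROM sales
        |PIVOT (sum(amount) FOR quarter IN ('Q1', 'Q2', 'Q3', 'Q4'))""".stripMargin)

    spark.sql("SELECT EXTRACT(year FROM order_date) FROM orders")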
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java
new file mode 100644
index 0000000000000..05879902a4ed9
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+/**
+ * Contains all the Utils methods used in the masking expressions.
+ */
+public class MaskExpressionsUtils {
+ static final int UNMASKED_VAL = -1;
+
+ /**
+ * Returns the masking character for {@param c}, or {@param c} itself if it should not be masked.
+ * @param c the character to transform
+ * @param maskedUpperChar the character to use instead of an uppercase letter
+ * @param maskedLowerChar the character to use instead of a lowercase letter
+ * @param maskedDigitChar the character to use instead of a digit
+ * @param maskedOtherChar the character to use instead of any other character
+ * @return masking character for {@param c}
+ */
+ public static int transformChar(
+ final int c,
+ int maskedUpperChar,
+ int maskedLowerChar,
+ int maskedDigitChar,
+ int maskedOtherChar) {
+ switch(Character.getType(c)) {
+ case Character.UPPERCASE_LETTER:
+ if(maskedUpperChar != UNMASKED_VAL) {
+ return maskedUpperChar;
+ }
+ break;
+
+ case Character.LOWERCASE_LETTER:
+ if(maskedLowerChar != UNMASKED_VAL) {
+ return maskedLowerChar;
+ }
+ break;
+
+ case Character.DECIMAL_DIGIT_NUMBER:
+ if(maskedDigitChar != UNMASKED_VAL) {
+ return maskedDigitChar;
+ }
+ break;
+
+ default:
+ if(maskedOtherChar != UNMASKED_VAL) {
+ return maskedOtherChar;
+ }
+ break;
+ }
+
+ return c;
+ }
+
+ /**
+ * Returns the replacement char to use according to the {@param rep} specified by the user and
+ * the {@param def} default.
+ */
+ public static int getReplacementChar(String rep, int def) {
+ if (rep != null && rep.length() > 0) {
+ return rep.codePointAt(0);
+ }
+ return def;
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index d18542b188f71..4dd2b7365652a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -27,6 +27,7 @@
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -55,9 +56,19 @@
public final class UnsafeArrayData extends ArrayData {
public static int calculateHeaderPortionInBytes(int numFields) {
+ return (int)calculateHeaderPortionInBytes((long)numFields);
+ }
+
+ public static long calculateHeaderPortionInBytes(long numFields) {
return 8 + ((numFields + 63)/ 64) * 8;
}
+ public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) {
+ long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) +
+ ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize);
+ return size;
+ }
+
private Object baseObject;
private long baseOffset;
@@ -72,7 +83,7 @@ public static int calculateHeaderPortionInBytes(int numFields) {
private long elementOffset;
private long getElementOffset(int ordinal, int elementSize) {
- return elementOffset + ordinal * elementSize;
+ return elementOffset + ordinal * (long)elementSize;
}
public Object getBaseObject() { return baseObject; }
@@ -230,7 +241,8 @@ public UTF8String getUTF8String(int ordinal) {
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) offsetAndSize;
- return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
+ MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size);
+ return new UTF8String(mb);
}
@Override
@@ -402,7 +414,7 @@ public byte[] toByteArray() {
public short[] toShortArray() {
short[] values = new short[numElements];
Platform.copyMemory(
- baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2);
+ baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2L);
return values;
}
@@ -410,7 +422,7 @@ public short[] toShortArray() {
public int[] toIntArray() {
int[] values = new int[numElements];
Platform.copyMemory(
- baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4);
+ baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4L);
return values;
}
@@ -418,7 +430,7 @@ public int[] toIntArray() {
public long[] toLongArray() {
long[] values = new long[numElements];
Platform.copyMemory(
- baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8);
+ baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8L);
return values;
}
@@ -426,7 +438,7 @@ public long[] toLongArray() {
public float[] toFloatArray() {
float[] values = new float[numElements];
Platform.copyMemory(
- baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4);
+ baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4L);
return values;
}
@@ -434,14 +446,14 @@ public float[] toFloatArray() {
public double[] toDoubleArray() {
double[] values = new double[numElements];
Platform.copyMemory(
- baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8);
+ baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8L);
return values;
}
private static UnsafeArrayData fromPrimitiveArray(
Object arr, int offset, int length, int elementSize) {
final long headerInBytes = calculateHeaderPortionInBytes(length);
- final long valueRegionInBytes = elementSize * length;
+ final long valueRegionInBytes = (long)elementSize * length;
final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
if (totalSizeInLongs > Integer.MAX_VALUE / 8) {
throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
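A quick worked check (illustrative, not Spark code) of the header formula above and of why the added long widening matters for large element counts:

// Illustrative check of the header/size arithmetic used by UnsafeArrayData.
public class UnsafeArraySizeCheck {
  static long headerBytes(long numFields) {
    return 8 + ((numFields + 63) / 64) * 8;   // 8-byte element count + null-bitset words
  }
  public static void main(String[] args) {
    System.out.println(headerBytes(1));       // 16: count word + one bitset word
    System.out.println(headerBytes(64));      // 16: 64 elements still fit in one bitset word
    System.out.println(headerBytes(65));      // 24: a second bitset word is needed
    // Why the (long) casts matter: 300 million elements * 8 bytes overflows int,
    // but is computed correctly once either operand is widened to long.
    int n = 300_000_000;
    System.out.println(n * 8);                // overflows to a negative int
    System.out.println(n * 8L);               // 2400000000
  }
}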
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 71c086029cc5b..469b0e60cc9a2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -37,6 +37,7 @@
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -61,6 +62,8 @@
*/
public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable {
+ public static final int WORD_SIZE = 8;
+
//////////////////////////////////////////////////////////////////////////////
// Static methods
//////////////////////////////////////////////////////////////////////////////
@@ -414,7 +417,8 @@ public UTF8String getUTF8String(int ordinal) {
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) offsetAndSize;
- return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
+ MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size);
+ return new UTF8String(mb);
}
@Override
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java
index 905e6820ce6e2..c823de4810f2b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java
@@ -41,7 +41,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
@Override
public UnsafeRow appendRow(Object kbase, long koff, int klen,
Object vbase, long voff, int vlen) {
- final long recordLength = 8 + klen + vlen + 8;
+ final long recordLength = 8L + klen + vlen + 8;
// if run out of max supported rows or page size, return null
if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) {
return null;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java
index f37ef83ad92b4..8e9c0a2e9dc81 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java
@@ -16,7 +16,8 @@
*/
package org.apache.spark.sql.catalyst.expressions;
-import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.types.UTF8String;
// scalastyle: off
/**
@@ -71,13 +72,13 @@ public static long hashLong(long input, long seed) {
return fmix(hash);
}
- public long hashUnsafeWords(Object base, long offset, int length) {
- return hashUnsafeWords(base, offset, length, seed);
+ public long hashUnsafeWordsBlock(MemoryBlock mb) {
+ return hashUnsafeWordsBlock(mb, seed);
}
- public static long hashUnsafeWords(Object base, long offset, int length, long seed) {
- assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)";
- long hash = hashBytesByWords(base, offset, length, seed);
+ public static long hashUnsafeWordsBlock(MemoryBlock mb, long seed) {
+ assert (mb.size() % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)";
+ long hash = hashBytesByWordsBlock(mb, seed);
return fmix(hash);
}
@@ -85,26 +86,36 @@ public long hashUnsafeBytes(Object base, long offset, int length) {
return hashUnsafeBytes(base, offset, length, seed);
}
- public static long hashUnsafeBytes(Object base, long offset, int length, long seed) {
+ public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) {
+ long offset = 0;
+ long length = mb.size();
assert (length >= 0) : "lengthInBytes cannot be negative";
- long hash = hashBytesByWords(base, offset, length, seed);
+ long hash = hashBytesByWordsBlock(mb, seed);
long end = offset + length;
offset += length & -8;
if (offset + 4L <= end) {
- hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1;
+ hash ^= (mb.getInt(offset) & 0xFFFFFFFFL) * PRIME64_1;
hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3;
offset += 4L;
}
while (offset < end) {
- hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5;
+ hash ^= (mb.getByte(offset) & 0xFFL) * PRIME64_5;
hash = Long.rotateLeft(hash, 11) * PRIME64_1;
offset++;
}
return fmix(hash);
}
+ public static long hashUTF8String(UTF8String str, long seed) {
+ return hashUnsafeBytesBlock(str.getMemoryBlock(), seed);
+ }
+
+ public static long hashUnsafeBytes(Object base, long offset, int length, long seed) {
+ return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed);
+ }
+
private static long fmix(long hash) {
hash ^= hash >>> 33;
hash *= PRIME64_2;
@@ -114,30 +125,31 @@ private static long fmix(long hash) {
return hash;
}
- private static long hashBytesByWords(Object base, long offset, int length, long seed) {
- long end = offset + length;
+ private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) {
+ long offset = 0;
+ long length = mb.size();
long hash;
if (length >= 32) {
- long limit = end - 32;
+ long limit = length - 32;
long v1 = seed + PRIME64_1 + PRIME64_2;
long v2 = seed + PRIME64_2;
long v3 = seed;
long v4 = seed - PRIME64_1;
do {
- v1 += Platform.getLong(base, offset) * PRIME64_2;
+ v1 += mb.getLong(offset) * PRIME64_2;
v1 = Long.rotateLeft(v1, 31);
v1 *= PRIME64_1;
- v2 += Platform.getLong(base, offset + 8) * PRIME64_2;
+ v2 += mb.getLong(offset + 8) * PRIME64_2;
v2 = Long.rotateLeft(v2, 31);
v2 *= PRIME64_1;
- v3 += Platform.getLong(base, offset + 16) * PRIME64_2;
+ v3 += mb.getLong(offset + 16) * PRIME64_2;
v3 = Long.rotateLeft(v3, 31);
v3 *= PRIME64_1;
- v4 += Platform.getLong(base, offset + 24) * PRIME64_2;
+ v4 += mb.getLong(offset + 24) * PRIME64_2;
v4 = Long.rotateLeft(v4, 31);
v4 *= PRIME64_1;
@@ -178,9 +190,9 @@ private static long hashBytesByWords(Object base, long offset, int length, long
hash += length;
- long limit = end - 8;
+ long limit = length - 8;
while (offset <= limit) {
- long k1 = Platform.getLong(base, offset);
+ long k1 = mb.getLong(offset);
hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1;
hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4;
offset += 8L;
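After this refactor the block-based methods are the primary entry points and the old (base, offset, length) overload simply wraps its arguments in a MemoryBlock. A small hedged sketch of that equivalence, with an arbitrary seed:

// Sketch only: exercises the refactored XXH64 entry points shown above.
import org.apache.spark.sql.catalyst.expressions.XXH64;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;

public class XXH64Example {
  public static void main(String[] args) {
    byte[] data = "hello world".getBytes();
    long seed = 42L;
    // Old-style call, now a thin wrapper that wraps the array in a MemoryBlock.
    long h1 = XXH64.hashUnsafeBytes(data, Platform.BYTE_ARRAY_OFFSET, data.length, seed);
    // Equivalent direct call on a MemoryBlock view of the same bytes.
    MemoryBlock mb = MemoryBlock.allocateFromObject(data, Platform.BYTE_ARRAY_OFFSET, data.length);
    long h2 = XXH64.hashUnsafeBytesBlock(mb, seed);
    System.out.println(h1 == h2);   // prints: true
  }
}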
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index 259976118c12f..537ef244b7e81 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -30,25 +30,21 @@
* this class per writing program, so that the memory segment/data buffer can be reused. Note that
* for each incoming record, we should call `reset` of BufferHolder instance before write the record
* and reuse the data buffer.
- *
- * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update
- * the size of the result row, after writing a record to the buffer. However, we can skip this step
- * if the fields of row are all fixed-length, as the size of result row is also fixed.
*/
-public class BufferHolder {
+final class BufferHolder {
private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
- public byte[] buffer;
- public int cursor = Platform.BYTE_ARRAY_OFFSET;
+ private byte[] buffer;
+ private int cursor = Platform.BYTE_ARRAY_OFFSET;
private final UnsafeRow row;
private final int fixedSize;
- public BufferHolder(UnsafeRow row) {
+ BufferHolder(UnsafeRow row) {
this(row, 64);
}
- public BufferHolder(UnsafeRow row, int initialSize) {
+ BufferHolder(UnsafeRow row, int initialSize) {
int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields());
if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) {
throw new UnsupportedOperationException(
@@ -64,7 +60,7 @@ public BufferHolder(UnsafeRow row, int initialSize) {
/**
* Grows the buffer by at least neededSize and points the row to the buffer.
*/
- public void grow(int neededSize) {
+ void grow(int neededSize) {
if (neededSize > ARRAY_MAX - totalSize()) {
throw new UnsupportedOperationException(
"Cannot grow BufferHolder by size " + neededSize + " because the size after growing " +
@@ -86,11 +82,23 @@ public void grow(int neededSize) {
}
}
- public void reset() {
+ byte[] getBuffer() {
+ return buffer;
+ }
+
+ int getCursor() {
+ return cursor;
+ }
+
+ void increaseCursor(int val) {
+ cursor += val;
+ }
+
+ void reset() {
cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize;
}
- public int totalSize() {
+ int totalSize() {
return cursor - Platform.BYTE_ARRAY_OFFSET;
}
}
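The grow path kept above doubles the required size until it would exceed the rounded array-length limit. A stand-alone illustration of that doubling policy (the constant is only a stand-in for ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH):

// Stand-alone illustration of the grow-by-doubling policy used by BufferHolder.
public class GrowthPolicyDemo {
  static final int ARRAY_MAX = Integer.MAX_VALUE - 15;   // stand-in for MAX_ROUNDED_ARRAY_LENGTH

  static int newCapacity(int usedBytes, int neededBytes) {
    int required = usedBytes + neededBytes;               // bytes the buffer must hold
    return required < ARRAY_MAX / 2 ? required * 2 : ARRAY_MAX;
  }

  public static void main(String[] args) {
    System.out.println(newCapacity(16, 10));               // 52: small requests double
    System.out.println(newCapacity(ARRAY_MAX / 2, 10));    // capped at ARRAY_MAX
  }
}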
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
index f0f66bae245fd..f8000d78cd1b6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java
@@ -19,6 +19,8 @@
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
+import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.types.UTF8String;
/**
@@ -29,43 +31,34 @@ public class UTF8StringBuilder {
private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;
- private byte[] buffer;
- private int cursor = Platform.BYTE_ARRAY_OFFSET;
+ private ByteArrayMemoryBlock buffer;
+ private int length = 0;
public UTF8StringBuilder() {
// Since initial buffer size is 16 in `StringBuilder`, we set the same size here
- this.buffer = new byte[16];
+ this.buffer = new ByteArrayMemoryBlock(16);
}
// Grows the buffer by at least `neededSize`
private void grow(int neededSize) {
- if (neededSize > ARRAY_MAX - totalSize()) {
+ if (neededSize > ARRAY_MAX - length) {
throw new UnsupportedOperationException(
"Cannot grow internal buffer by size " + neededSize + " because the size after growing " +
"exceeds size limitation " + ARRAY_MAX);
}
- final int length = totalSize() + neededSize;
- if (buffer.length < length) {
- int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX;
- final byte[] tmp = new byte[newLength];
- Platform.copyMemory(
- buffer,
- Platform.BYTE_ARRAY_OFFSET,
- tmp,
- Platform.BYTE_ARRAY_OFFSET,
- totalSize());
+ final int requestedSize = length + neededSize;
+ if (buffer.size() < requestedSize) {
+ int newLength = requestedSize < ARRAY_MAX / 2 ? requestedSize * 2 : ARRAY_MAX;
+ final ByteArrayMemoryBlock tmp = new ByteArrayMemoryBlock(newLength);
+ MemoryBlock.copyMemory(buffer, tmp, length);
buffer = tmp;
}
}
- private int totalSize() {
- return cursor - Platform.BYTE_ARRAY_OFFSET;
- }
-
public void append(UTF8String value) {
grow(value.numBytes());
- value.writeToMemory(buffer, cursor);
- cursor += value.numBytes();
+ value.writeToMemory(buffer.getByteArray(), length + Platform.BYTE_ARRAY_OFFSET);
+ length += value.numBytes();
}
public void append(String value) {
@@ -73,6 +66,6 @@ public void append(String value) {
}
public UTF8String build() {
- return UTF8String.fromBytes(buffer, 0, totalSize());
+ return UTF8String.fromBytes(buffer.getByteArray(), 0, length);
}
}
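Caller-visible behavior of the builder is unchanged by the move to ByteArrayMemoryBlock; a minimal usage sketch with illustrative values:

// Illustrative use of UTF8StringBuilder; the result is the concatenation of the appends.
import org.apache.spark.sql.catalyst.expressions.codegen.UTF8StringBuilder;
import org.apache.spark.unsafe.types.UTF8String;

public class UTF8StringBuilderExample {
  public static void main(String[] args) {
    UTF8StringBuilder sb = new UTF8StringBuilder();
    sb.append(UTF8String.fromString("foo"));
    sb.append("bar");                        // String overload kept by this file
    UTF8String result = sb.build();
    System.out.println(result);              // prints: foobar
  }
}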
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 791e8d80e6cba..a78dd970d23e4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -21,8 +21,6 @@
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
-import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes;
@@ -30,16 +28,14 @@
* A helper class to write data into global row buffer using `UnsafeArrayData` format,
* used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
*/
-public class UnsafeArrayWriter {
-
- private BufferHolder holder;
-
- // The offset of the global buffer where we start to write this array.
- private int startingOffset;
+public final class UnsafeArrayWriter extends UnsafeWriter {
// The number of elements in this array
private int numElements;
+ // The element size in this array
+ private int elementSize;
+
private int headerInBytes;
private void assertIndexIsValid(int index) {
@@ -47,13 +43,17 @@ private void assertIndexIsValid(int index) {
assert index < numElements : "index (" + index + ") should < " + numElements;
}
- public void initialize(BufferHolder holder, int numElements, int elementSize) {
+ public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) {
+ super(writer.getBufferHolder());
+ this.elementSize = elementSize;
+ }
+
+ public void initialize(int numElements) {
// We need 8 bytes to store numElements in header
this.numElements = numElements;
this.headerInBytes = calculateHeaderPortionInBytes(numElements);
- this.holder = holder;
- this.startingOffset = holder.cursor;
+ this.startingOffset = cursor();
// Grows the global buffer ahead for header and fixed size data.
int fixedPartInBytes =
@@ -61,130 +61,92 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) {
holder.grow(headerInBytes + fixedPartInBytes);
// Write numElements and clear out null bits to header
- Platform.putLong(holder.buffer, startingOffset, numElements);
+ Platform.putLong(getBuffer(), startingOffset, numElements);
for (int i = 8; i < headerInBytes; i += 8) {
- Platform.putLong(holder.buffer, startingOffset + i, 0L);
+ Platform.putLong(getBuffer(), startingOffset + i, 0L);
}
// fill 0 into reminder part of 8-bytes alignment in unsafe array
for (int i = elementSize * numElements; i < fixedPartInBytes; i++) {
- Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0);
- }
- holder.cursor += (headerInBytes + fixedPartInBytes);
- }
-
- private void zeroOutPaddingBytes(int numBytes) {
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
+ Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0);
}
+ increaseCursor(headerInBytes + fixedPartInBytes);
}
- private long getElementOffset(int ordinal, int elementSize) {
+ private long getElementOffset(int ordinal) {
return startingOffset + headerInBytes + ordinal * elementSize;
}
- public void setOffsetAndSize(int ordinal, long currentCursor, int size) {
- assertIndexIsValid(ordinal);
- final long relativeOffset = currentCursor - startingOffset;
- final long offsetAndSize = (relativeOffset << 32) | (long)size;
-
- write(ordinal, offsetAndSize);
- }
-
private void setNullBit(int ordinal) {
assertIndexIsValid(ordinal);
- BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal);
+ BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal);
}
- public void setNullBoolean(int ordinal) {
+ public void setNull1Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false);
+ writeByte(getElementOffset(ordinal), (byte)0);
}
- public void setNullByte(int ordinal) {
+ public void setNull2Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0);
+ writeShort(getElementOffset(ordinal), (short)0);
}
- public void setNullShort(int ordinal) {
+ public void setNull4Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0);
+ writeInt(getElementOffset(ordinal), 0);
}
- public void setNullInt(int ordinal) {
+ public void setNull8Bytes(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
- Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0);
+ writeLong(getElementOffset(ordinal), 0);
}
- public void setNullLong(int ordinal) {
- setNullBit(ordinal);
- // put zero into the corresponding field when set null
- Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0);
- }
-
- public void setNullFloat(int ordinal) {
- setNullBit(ordinal);
- // put zero into the corresponding field when set null
- Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0);
- }
-
- public void setNullDouble(int ordinal) {
- setNullBit(ordinal);
- // put zero into the corresponding field when set null
- Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0);
- }
-
- public void setNull(int ordinal) { setNullLong(ordinal); }
+ public void setNull(int ordinal) { setNull8Bytes(ordinal); }
public void write(int ordinal, boolean value) {
assertIndexIsValid(ordinal);
- Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value);
+ writeBoolean(getElementOffset(ordinal), value);
}
public void write(int ordinal, byte value) {
assertIndexIsValid(ordinal);
- Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value);
+ writeByte(getElementOffset(ordinal), value);
}
public void write(int ordinal, short value) {
assertIndexIsValid(ordinal);
- Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value);
+ writeShort(getElementOffset(ordinal), value);
}
public void write(int ordinal, int value) {
assertIndexIsValid(ordinal);
- Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value);
+ writeInt(getElementOffset(ordinal), value);
}
public void write(int ordinal, long value) {
assertIndexIsValid(ordinal);
- Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value);
+ writeLong(getElementOffset(ordinal), value);
}
public void write(int ordinal, float value) {
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
assertIndexIsValid(ordinal);
- Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value);
+ writeFloat(getElementOffset(ordinal), value);
}
public void write(int ordinal, double value) {
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
assertIndexIsValid(ordinal);
- Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value);
+ writeDouble(getElementOffset(ordinal), value);
}
public void write(int ordinal, Decimal input, int precision, int scale) {
// make sure Decimal object has the same scale as DecimalType
assertIndexIsValid(ordinal);
- if (input.changePrecision(precision, scale)) {
+ if (input != null && input.changePrecision(precision, scale)) {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
write(ordinal, input.toUnscaledLong());
} else {
@@ -198,65 +160,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
// Write the bytes to the variable length portion.
Platform.copyMemory(
- bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
- setOffsetAndSize(ordinal, holder.cursor, numBytes);
+ bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes);
+ setOffsetAndSize(ordinal, numBytes);
// move the cursor forward with 8-bytes boundary
- holder.cursor += roundedSize;
+ increaseCursor(roundedSize);
}
} else {
setNull(ordinal);
}
}
-
- public void write(int ordinal, UTF8String input) {
- final int numBytes = input.numBytes();
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- input.writeToMemory(holder.buffer, holder.cursor);
-
- setOffsetAndSize(ordinal, holder.cursor, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, byte[] input) {
- final int numBytes = input.length;
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- Platform.copyMemory(
- input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
-
- setOffsetAndSize(ordinal, holder.cursor, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, CalendarInterval input) {
- // grow the global buffer before writing data.
- holder.grow(16);
-
- // Write the months and microseconds fields of Interval to the variable length portion.
- Platform.putLong(holder.buffer, holder.cursor, input.months);
- Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
-
- setOffsetAndSize(ordinal, holder.cursor, 16);
-
- // move the cursor forward.
- holder.cursor += 16;
- }
}
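With this change the array writer is constructed once per element size with its parent writer and re-initialized per array. A hedged sketch of the intended flow, assuming the parent is an UnsafeRowWriter and 4-byte (int) elements:

// Sketch: writing a fixed-length int array through the refactored writer API.
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;

public class ArrayWriterSketch {
  public static void main(String[] args) {
    UnsafeRowWriter rowWriter = new UnsafeRowWriter(1);               // one field: the array
    UnsafeArrayWriter arrayWriter = new UnsafeArrayWriter(rowWriter, 4); // 4-byte elements
    rowWriter.reset();
    rowWriter.zeroOutNullBytes();
    int previousCursor = rowWriter.cursor();
    arrayWriter.initialize(3);                                         // 3 elements
    arrayWriter.write(0, 10);
    arrayWriter.write(1, 20);
    arrayWriter.setNull4Bytes(2);
    rowWriter.setOffsetAndSizeFromPreviousCursor(0, previousCursor);
    UnsafeRow row = rowWriter.getRow();
    System.out.println(row.getArray(0).numElements());                // prints: 3
  }
}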
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 5d9515c0725da..71c49d8ed0177 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -20,10 +20,7 @@
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
-import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.spark.unsafe.types.UTF8String;
/**
* A helper class to write data into global row buffer using `UnsafeRow` format.
@@ -31,38 +28,67 @@
* It will remember the offset of row buffer which it starts to write, and move the cursor of row
* buffer while writing. If new data(can be the input record if this is the outermost writer, or
* nested struct if this is an inner writer) comes, the starting cursor of row buffer may be
- * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the
+ * changed, so we need to call `UnsafeRowWriter.resetRowWriter` before writing, to update the
* `startingOffset` and clear out null bits.
*
* Note that if this is the outermost writer, which means we will always write from the very
* beginning of the global row buffer, we don't need to update `startingOffset` and can just call
* `zeroOutNullBytes` before writing new data.
*/
-public class UnsafeRowWriter {
+public final class UnsafeRowWriter extends UnsafeWriter {
+
+ private final UnsafeRow row;
- private final BufferHolder holder;
- // The offset of the global buffer where we start to write this row.
- private int startingOffset;
private final int nullBitsSize;
private final int fixedSize;
- public UnsafeRowWriter(BufferHolder holder, int numFields) {
- this.holder = holder;
+ public UnsafeRowWriter(int numFields) {
+ this(new UnsafeRow(numFields));
+ }
+
+ public UnsafeRowWriter(int numFields, int initialBufferSize) {
+ this(new UnsafeRow(numFields), initialBufferSize);
+ }
+
+ public UnsafeRowWriter(UnsafeWriter writer, int numFields) {
+ this(null, writer.getBufferHolder(), numFields);
+ }
+
+ private UnsafeRowWriter(UnsafeRow row) {
+ this(row, new BufferHolder(row), row.numFields());
+ }
+
+ private UnsafeRowWriter(UnsafeRow row, int initialBufferSize) {
+ this(row, new BufferHolder(row, initialBufferSize), row.numFields());
+ }
+
+ private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) {
+ super(holder);
+ this.row = row;
this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
this.fixedSize = nullBitsSize + 8 * numFields;
- this.startingOffset = holder.cursor;
+ this.startingOffset = cursor();
+ }
+
+ /**
+ * Updates total size of the UnsafeRow using the size collected by BufferHolder, and returns
+ * the UnsafeRow created in the constructor.
+ */
+ public UnsafeRow getRow() {
+ row.setTotalSize(totalSize());
+ return row;
}
/**
* Resets the `startingOffset` according to the current cursor of row buffer, and clear out null
* bits. This should be called before we write a new nested struct to the row buffer.
*/
- public void reset() {
- this.startingOffset = holder.cursor;
+ public void resetRowWriter() {
+ this.startingOffset = cursor();
// grow the global buffer to make sure it has enough space to write fixed-length data.
- holder.grow(fixedSize);
- holder.cursor += fixedSize;
+ grow(fixedSize);
+ increaseCursor(fixedSize);
zeroOutNullBytes();
}
@@ -72,92 +98,86 @@ public void reset() {
*/
public void zeroOutNullBytes() {
for (int i = 0; i < nullBitsSize; i += 8) {
- Platform.putLong(holder.buffer, startingOffset + i, 0L);
+ Platform.putLong(getBuffer(), startingOffset + i, 0L);
}
}
- private void zeroOutPaddingBytes(int numBytes) {
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
- }
- }
-
- public BufferHolder holder() { return holder; }
-
public boolean isNullAt(int ordinal) {
- return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal);
+ return BitSetMethods.isSet(getBuffer(), startingOffset, ordinal);
}
public void setNullAt(int ordinal) {
- BitSetMethods.set(holder.buffer, startingOffset, ordinal);
- Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
+ BitSetMethods.set(getBuffer(), startingOffset, ordinal);
+ write(ordinal, 0L);
}
- public long getFieldOffset(int ordinal) {
- return startingOffset + nullBitsSize + 8 * ordinal;
+ @Override
+ public void setNull1Bytes(int ordinal) {
+ setNullAt(ordinal);
+ }
+
+ @Override
+ public void setNull2Bytes(int ordinal) {
+ setNullAt(ordinal);
}
- public void setOffsetAndSize(int ordinal, long size) {
- setOffsetAndSize(ordinal, holder.cursor, size);
+ @Override
+ public void setNull4Bytes(int ordinal) {
+ setNullAt(ordinal);
}
- public void setOffsetAndSize(int ordinal, long currentCursor, long size) {
- final long relativeOffset = currentCursor - startingOffset;
- final long fieldOffset = getFieldOffset(ordinal);
- final long offsetAndSize = (relativeOffset << 32) | size;
+ @Override
+ public void setNull8Bytes(int ordinal) {
+ setNullAt(ordinal);
+ }
- Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
+ public long getFieldOffset(int ordinal) {
+ return startingOffset + nullBitsSize + 8 * ordinal;
}
public void write(int ordinal, boolean value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putBoolean(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeBoolean(offset, value);
}
public void write(int ordinal, byte value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putByte(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeByte(offset, value);
}
public void write(int ordinal, short value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putShort(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeShort(offset, value);
}
public void write(int ordinal, int value) {
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putInt(holder.buffer, offset, value);
+ writeLong(offset, 0L);
+ writeInt(offset, value);
}
public void write(int ordinal, long value) {
- Platform.putLong(holder.buffer, getFieldOffset(ordinal), value);
+ writeLong(getFieldOffset(ordinal), value);
}
public void write(int ordinal, float value) {
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
final long offset = getFieldOffset(ordinal);
- Platform.putLong(holder.buffer, offset, 0L);
- Platform.putFloat(holder.buffer, offset, value);
+ writeLong(offset, 0);
+ writeFloat(offset, value);
}
public void write(int ordinal, double value) {
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
- Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value);
+ writeDouble(getFieldOffset(ordinal), value);
}
public void write(int ordinal, Decimal input, int precision, int scale) {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
// make sure Decimal object has the same scale as DecimalType
- if (input.changePrecision(precision, scale)) {
- Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+ if (input != null && input.changePrecision(precision, scale)) {
+ write(ordinal, input.toUnscaledLong());
} else {
setNullAt(ordinal);
}
@@ -165,82 +185,31 @@ public void write(int ordinal, Decimal input, int precision, int scale) {
// grow the global buffer before writing data.
holder.grow(16);
- // zero-out the bytes
- Platform.putLong(holder.buffer, holder.cursor, 0L);
- Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
-
// Make sure Decimal object has the same scale as DecimalType.
// Note that we may pass in null Decimal object to set null for it.
if (input == null || !input.changePrecision(precision, scale)) {
- BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+ // zero-out the bytes
+ Platform.putLong(getBuffer(), cursor(), 0L);
+ Platform.putLong(getBuffer(), cursor() + 8, 0L);
+
+ BitSetMethods.set(getBuffer(), startingOffset, ordinal);
// keep the offset for future update
- setOffsetAndSize(ordinal, 0L);
+ setOffsetAndSize(ordinal, 0);
} else {
final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
- assert bytes.length <= 16;
+ final int numBytes = bytes.length;
+ assert numBytes <= 16;
+
+ zeroOutPaddingBytes(numBytes);
// Write the bytes to the variable length portion.
Platform.copyMemory(
- bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+ bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes);
setOffsetAndSize(ordinal, bytes.length);
}
// move the cursor forward.
- holder.cursor += 16;
+ increaseCursor(16);
}
}
-
- public void write(int ordinal, UTF8String input) {
- final int numBytes = input.numBytes();
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- input.writeToMemory(holder.buffer, holder.cursor);
-
- setOffsetAndSize(ordinal, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, byte[] input) {
- write(ordinal, input, 0, input.length);
- }
-
- public void write(int ordinal, byte[] input, int offset, int numBytes) {
- final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
-
- // grow the global buffer before writing data.
- holder.grow(roundedSize);
-
- zeroOutPaddingBytes(numBytes);
-
- // Write the bytes to the variable length portion.
- Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset,
- holder.buffer, holder.cursor, numBytes);
-
- setOffsetAndSize(ordinal, numBytes);
-
- // move the cursor forward.
- holder.cursor += roundedSize;
- }
-
- public void write(int ordinal, CalendarInterval input) {
- // grow the global buffer before writing data.
- holder.grow(16);
-
- // Write the months and microseconds fields of Interval to the variable length portion.
- Platform.putLong(holder.buffer, holder.cursor, input.months);
- Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
-
- setOffsetAndSize(ordinal, 16);
-
- // move the cursor forward.
- holder.cursor += 16;
- }
}
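Since the writer now creates and owns its UnsafeRow and BufferHolder, the simplest usage no longer touches a holder directly. A minimal sketch with arbitrary field values:

// Sketch: building a two-column row with the refactored UnsafeRowWriter.
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
import org.apache.spark.unsafe.types.UTF8String;

public class RowWriterSketch {
  public static void main(String[] args) {
    UnsafeRowWriter writer = new UnsafeRowWriter(2);    // numFields = 2
    writer.reset();                                      // position cursor past the fixed region
    writer.zeroOutNullBytes();
    writer.write(0, 123L);                               // fixed-length long
    writer.write(1, UTF8String.fromString("abc"));       // variable-length string
    UnsafeRow row = writer.getRow();                     // sets total size and returns the row
    System.out.println(row.getLong(0) + ", " + row.getUTF8String(1));   // 123, abc
  }
}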
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
new file mode 100644
index 0000000000000..2781655002000
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.expressions.UnsafeMapData;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * Base class for writing Unsafe* structures.
+ */
+public abstract class UnsafeWriter {
+ // Keep internal buffer holder
+ protected final BufferHolder holder;
+
+ // The offset of the global buffer where we start to write this structure.
+ protected int startingOffset;
+
+ protected UnsafeWriter(BufferHolder holder) {
+ this.holder = holder;
+ }
+
+ /**
+ * Accessor methods that delegate to the internal BufferHolder.
+ */
+ public final BufferHolder getBufferHolder() {
+ return holder;
+ }
+
+ public final byte[] getBuffer() {
+ return holder.getBuffer();
+ }
+
+ public final void reset() {
+ holder.reset();
+ }
+
+ public final int totalSize() {
+ return holder.totalSize();
+ }
+
+ public final void grow(int neededSize) {
+ holder.grow(neededSize);
+ }
+
+ public final int cursor() {
+ return holder.getCursor();
+ }
+
+ public final void increaseCursor(int val) {
+ holder.increaseCursor(val);
+ }
+
+ public final void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) {
+ setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor);
+ }
+
+ protected void setOffsetAndSize(int ordinal, int size) {
+ setOffsetAndSize(ordinal, cursor(), size);
+ }
+
+ protected void setOffsetAndSize(int ordinal, int currentCursor, int size) {
+ final long relativeOffset = currentCursor - startingOffset;
+ final long offsetAndSize = (relativeOffset << 32) | (long)size;
+
+ write(ordinal, offsetAndSize);
+ }
+
+ protected final void zeroOutPaddingBytes(int numBytes) {
+ if ((numBytes & 0x07) > 0) {
+ Platform.putLong(getBuffer(), cursor() + ((numBytes >> 3) << 3), 0L);
+ }
+ }
+
+ public abstract void setNull1Bytes(int ordinal);
+ public abstract void setNull2Bytes(int ordinal);
+ public abstract void setNull4Bytes(int ordinal);
+ public abstract void setNull8Bytes(int ordinal);
+
+ public abstract void write(int ordinal, boolean value);
+ public abstract void write(int ordinal, byte value);
+ public abstract void write(int ordinal, short value);
+ public abstract void write(int ordinal, int value);
+ public abstract void write(int ordinal, long value);
+ public abstract void write(int ordinal, float value);
+ public abstract void write(int ordinal, double value);
+ public abstract void write(int ordinal, Decimal input, int precision, int scale);
+
+ public final void write(int ordinal, UTF8String input) {
+ writeUnalignedBytes(ordinal, input.getBaseObject(), input.getBaseOffset(), input.numBytes());
+ }
+
+ public final void write(int ordinal, byte[] input) {
+ write(ordinal, input, 0, input.length);
+ }
+
+ public final void write(int ordinal, byte[] input, int offset, int numBytes) {
+ writeUnalignedBytes(ordinal, input, Platform.BYTE_ARRAY_OFFSET + offset, numBytes);
+ }
+
+ private void writeUnalignedBytes(
+ int ordinal,
+ Object baseObject,
+ long baseOffset,
+ int numBytes) {
+ final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+ grow(roundedSize);
+ zeroOutPaddingBytes(numBytes);
+ Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes);
+ setOffsetAndSize(ordinal, numBytes);
+ increaseCursor(roundedSize);
+ }
+
+ public final void write(int ordinal, CalendarInterval input) {
+ // grow the global buffer before writing data.
+ grow(16);
+
+ // Write the months and microseconds fields of Interval to the variable length portion.
+ Platform.putLong(getBuffer(), cursor(), input.months);
+ Platform.putLong(getBuffer(), cursor() + 8, input.microseconds);
+
+ setOffsetAndSize(ordinal, 16);
+
+ // move the cursor forward.
+ increaseCursor(16);
+ }
+
+ public final void write(int ordinal, UnsafeRow row) {
+ writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes());
+ }
+
+ public final void write(int ordinal, UnsafeMapData map) {
+ writeAlignedBytes(ordinal, map.getBaseObject(), map.getBaseOffset(), map.getSizeInBytes());
+ }
+
+ public final void write(UnsafeArrayData array) {
+ // Unsafe arrays can be written both as a regular array field and as part of a map. This makes
+ // updating the offset and size dependent on the code path, which is why we currently do not
+ // provide a method for writing unsafe arrays that also updates the size and offset.
+ int numBytes = array.getSizeInBytes();
+ grow(numBytes);
+ Platform.copyMemory(
+ array.getBaseObject(),
+ array.getBaseOffset(),
+ getBuffer(),
+ cursor(),
+ numBytes);
+ increaseCursor(numBytes);
+ }
+
+ private void writeAlignedBytes(
+ int ordinal,
+ Object baseObject,
+ long baseOffset,
+ int numBytes) {
+ grow(numBytes);
+ Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes);
+ setOffsetAndSize(ordinal, numBytes);
+ increaseCursor(numBytes);
+ }
+
+ protected final void writeBoolean(long offset, boolean value) {
+ Platform.putBoolean(getBuffer(), offset, value);
+ }
+
+ protected final void writeByte(long offset, byte value) {
+ Platform.putByte(getBuffer(), offset, value);
+ }
+
+ protected final void writeShort(long offset, short value) {
+ Platform.putShort(getBuffer(), offset, value);
+ }
+
+ protected final void writeInt(long offset, int value) {
+ Platform.putInt(getBuffer(), offset, value);
+ }
+
+ protected final void writeLong(long offset, long value) {
+ Platform.putLong(getBuffer(), offset, value);
+ }
+
+ protected final void writeFloat(long offset, float value) {
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
+ }
+ Platform.putFloat(getBuffer(), offset, value);
+ }
+
+ protected final void writeDouble(long offset, double value) {
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
+ }
+ Platform.putDouble(getBuffer(), offset, value);
+ }
+}
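setOffsetAndSize above packs a variable-length field's relative offset and byte size into one long, which readers such as UnsafeRow.getUTF8String unpack with a shift and a narrowing cast. A small worked example with illustrative numbers:

// Worked example of the (offset << 32) | size packing used by UnsafeWriter.setOffsetAndSize.
public class OffsetAndSizeDemo {
  public static void main(String[] args) {
    long relativeOffset = 48;                      // bytes from the writer's starting offset
    int size = 11;                                 // payload length in bytes
    long offsetAndSize = (relativeOffset << 32) | (long) size;

    // Readers recover both halves exactly as UnsafeRow/UnsafeArrayData do:
    int decodedOffset = (int) (offsetAndSize >> 32);
    int decodedSize = (int) offsetAndSize;
    System.out.println(decodedOffset + " " + decodedSize);   // prints: 48 11
  }
}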
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
index d224332d8a6c9..023ec139652c5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
@@ -21,6 +21,9 @@
import java.io.Reader;
import javax.xml.namespace.QName;
+import javax.xml.parsers.DocumentBuilder;
+import javax.xml.parsers.DocumentBuilderFactory;
+import javax.xml.parsers.ParserConfigurationException;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpression;
@@ -37,9 +40,15 @@
* This is based on Hive's UDFXPathUtil implementation.
*/
public class UDFXPathUtil {
+ public static final String SAX_FEATURE_PREFIX = "http://xml.org/sax/features/";
+ public static final String EXTERNAL_GENERAL_ENTITIES_FEATURE = "external-general-entities";
+ public static final String EXTERNAL_PARAMETER_ENTITIES_FEATURE = "external-parameter-entities";
+ private DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
+ private DocumentBuilder builder = null;
private XPath xpath = XPathFactory.newInstance().newXPath();
private ReusableStringReader reader = new ReusableStringReader();
private InputSource inputSource = new InputSource(reader);
+
private XPathExpression expression = null;
private String oldPath = null;
@@ -65,14 +74,31 @@ public Object eval(String xml, String path, QName qname) throws XPathExpressionE
return null;
}
+ if (builder == null){
+ try {
+ initializeDocumentBuilderFactory();
+ builder = dbf.newDocumentBuilder();
+ } catch (ParserConfigurationException e) {
+ throw new RuntimeException(
+ "Error instantiating DocumentBuilder, cannot build xml parser", e);
+ }
+ }
+
reader.set(xml);
try {
- return expression.evaluate(inputSource, qname);
+ return expression.evaluate(builder.parse(inputSource), qname);
} catch (XPathExpressionException e) {
throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e);
+ } catch (Exception e) {
+ throw new RuntimeException("Error loading expression '" + oldPath + "'", e);
}
}
+ private void initializeDocumentBuilderFactory() throws ParserConfigurationException {
+ dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_GENERAL_ENTITIES_FEATURE, false);
+ dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_PARAMETER_ENTITIES_FEATURE, false);
+ }
+
public Boolean evalBoolean(String xml, String path) throws XPathExpressionException {
return (Boolean) eval(xml, path, XPathConstants.BOOLEAN);
}
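The DocumentBuilder introduced above is configured with external general and parameter entities disabled before any user-supplied XML is parsed, which mitigates XXE. A stand-alone sketch of the same JAXP hardening, using only standard javax.xml APIs:

// Minimal XXE-hardened parse, mirroring the features disabled above (standard JAXP APIs).
import java.io.StringReader;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.xml.sax.InputSource;

public class HardenedXmlParse {
  public static Document parse(String xml) throws Exception {
    DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
    dbf.setFeature("http://xml.org/sax/features/external-general-entities", false);
    dbf.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
    DocumentBuilder builder = dbf.newDocumentBuilder();
    return builder.parse(new InputSource(new StringReader(xml)));
  }

  public static void main(String[] args) throws Exception {
    Document doc = parse("<a><b>1</b></a>");
    System.out.println(doc.getDocumentElement().getTagName());   // prints: a
  }
}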
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index ccdb6bc5d4b7c..7b02317b8538f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -68,10 +68,10 @@ import org.apache.spark.sql.types._
*/
@Experimental
@InterfaceStability.Evolving
-@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " +
- "(Int, String, etc) and Product types (case classes) are supported by importing " +
- "spark.implicits._ Support for serializing other types will be added in future " +
- "releases.")
+@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " +
+ "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " +
+ "classes) are supported by importing spark.implicits._ Support for serializing other types " +
+ "will be added in future releases.")
trait Encoder[T] extends Serializable {
/** Returns the schema of encoding this type of object as a Row. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index 0b95a8821b05a..b47ec0b72c638 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -132,7 +132,7 @@ object Encoders {
* - primitive types: boolean, int, double, etc.
* - boxed types: Boolean, Integer, Double, etc.
* - String
- * - java.math.BigDecimal
+ * - java.math.BigDecimal, java.math.BigInteger
* - time related: java.sql.Date, java.sql.Timestamp
* - collection types: only array and java.util.List currently, map support is in progress
* - nested java bean.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 474ec592201d9..9e9105a157abe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -170,6 +170,9 @@ object CatalystTypeConverters {
convertedIterable += elementConverter.toCatalyst(item)
}
new GenericArrayData(convertedIterable.toArray)
+ case other => throw new IllegalArgumentException(
+ s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ + s"cannot be converted to an array of ${elementType.catalogString}")
}
}
@@ -206,6 +209,10 @@ object CatalystTypeConverters {
scalaValue match {
case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction)
case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction)
+ case other => throw new IllegalArgumentException(
+ s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ + "cannot be converted to a map type with "
+ + s"key type (${keyType.catalogString}) and value type (${valueType.catalogString})")
}
}
@@ -252,6 +259,9 @@ object CatalystTypeConverters {
idx += 1
}
new GenericInternalRow(ar)
+ case other => throw new IllegalArgumentException(
+ s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ + s"cannot be converted to ${structType.catalogString}")
}
override def toScala(row: InternalRow): Row = {
@@ -276,6 +286,9 @@ object CatalystTypeConverters {
override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
case str: String => UTF8String.fromString(str)
case utf8: UTF8String => utf8
+ case other => throw new IllegalArgumentException(
+ s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ + s"cannot be converted to the string type")
}
override def toScala(catalystValue: UTF8String): String =
if (catalystValue == null) null else catalystValue.toString
@@ -309,6 +322,9 @@ object CatalystTypeConverters {
case d: JavaBigDecimal => Decimal(d)
case d: JavaBigInteger => Decimal(d)
case d: Decimal => d
+ case other => throw new IllegalArgumentException(
+ s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ + s"cannot be converted to ${dataType.catalogString}")
}
decimal.toPrecision(dataType.precision, dataType.scale)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 29110640d64f2..274d75e680f03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
-import org.apache.spark.sql.types.{DataType, Decimal, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -119,4 +119,28 @@ object InternalRow {
case v: MapData => v.copy()
case _ => value
}
+
+ /**
+ * Returns an accessor for an `InternalRow` with given data type. The returned accessor
+ * actually takes a `SpecializedGetters` input because it can be generalized to other classes
+ * that implements `SpecializedGetters` (e.g., `ArrayData`) too.
+ */
+ def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match {
+ case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
+ case ByteType => (input, ordinal) => input.getByte(ordinal)
+ case ShortType => (input, ordinal) => input.getShort(ordinal)
+ case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
+ case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
+ case FloatType => (input, ordinal) => input.getFloat(ordinal)
+ case DoubleType => (input, ordinal) => input.getDouble(ordinal)
+ case StringType => (input, ordinal) => input.getUTF8String(ordinal)
+ case BinaryType => (input, ordinal) => input.getBinary(ordinal)
+ case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
+ case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
+ case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
+ case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
+ case _: MapType => (input, ordinal) => input.getMap(ordinal)
+ case u: UserDefinedType[_] => getAccessor(u.sqlType)
+ case _ => (input, ordinal) => input.get(ordinal, dataType)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9a4bf0075a178..f9acc208b715e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,10 +17,14 @@
package org.apache.spark.sql.catalyst
+import java.lang.reflect.Constructor
+
+import org.apache.commons.lang3.reflect.ConstructorUtils
+
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -382,22 +386,22 @@ object ScalaReflection extends ScalaReflection {
val clsName = getClassNameFromType(fieldType)
val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
// For tuples, we based grab the inner fields by ordinal instead of name.
- if (cls.getName startsWith "scala.Tuple") {
+ val constructor = if (cls.getName startsWith "scala.Tuple") {
deserializerFor(
fieldType,
Some(addToPathOrdinal(i, dataType, newTypePath)),
newTypePath)
} else {
- val constructor = deserializerFor(
+ deserializerFor(
fieldType,
Some(addToPath(fieldName, dataType, newTypePath)),
newTypePath)
+ }
- if (!nullable) {
- AssertNotNull(constructor, newTypePath)
- } else {
- constructor
- }
+ if (!nullable) {
+ AssertNotNull(constructor, newTypePath)
+ } else {
+ constructor
}
}
@@ -781,6 +785,15 @@ object ScalaReflection extends ScalaReflection {
}
}
+ /**
+ * Finds an accessible constructor with compatible parameters. This is a more flexible search
+ * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
+ * matching constructor is returned. Otherwise, it returns `None`.
+ */
+ def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = {
+ Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*))
+ }
+
/**
* Whether the fields of the given type is defined entirely by its constructor parameters.
*/
@@ -794,6 +807,65 @@ object ScalaReflection extends ScalaReflection {
"interface", "long", "native", "new", "null", "package", "private", "protected", "public",
"return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw",
"throws", "transient", "true", "try", "void", "volatile", "while")
+
+ val typeJavaMapping = Map[DataType, Class[_]](
+ BooleanType -> classOf[Boolean],
+ ByteType -> classOf[Byte],
+ ShortType -> classOf[Short],
+ IntegerType -> classOf[Int],
+ LongType -> classOf[Long],
+ FloatType -> classOf[Float],
+ DoubleType -> classOf[Double],
+ StringType -> classOf[UTF8String],
+ DateType -> classOf[DateType.InternalType],
+ TimestampType -> classOf[TimestampType.InternalType],
+ BinaryType -> classOf[BinaryType.InternalType],
+ CalendarIntervalType -> classOf[CalendarInterval]
+ )
+
+ val typeBoxedJavaMapping = Map[DataType, Class[_]](
+ BooleanType -> classOf[java.lang.Boolean],
+ ByteType -> classOf[java.lang.Byte],
+ ShortType -> classOf[java.lang.Short],
+ IntegerType -> classOf[java.lang.Integer],
+ LongType -> classOf[java.lang.Long],
+ FloatType -> classOf[java.lang.Float],
+ DoubleType -> classOf[java.lang.Double],
+ DateType -> classOf[java.lang.Integer],
+ TimestampType -> classOf[java.lang.Long]
+ )
+
+ def dataTypeJavaClass(dt: DataType): Class[_] = {
+ dt match {
+ case _: DecimalType => classOf[Decimal]
+ case _: StructType => classOf[InternalRow]
+ case _: ArrayType => classOf[ArrayData]
+ case _: MapType => classOf[MapData]
+ case ObjectType(cls) => cls
+ case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object])
+ }
+ }
+
+ def javaBoxedType(dt: DataType): Class[_] = dt match {
+ case _: DecimalType => classOf[Decimal]
+ case BinaryType => classOf[Array[Byte]]
+ case StringType => classOf[UTF8String]
+ case CalendarIntervalType => classOf[CalendarInterval]
+ case _: StructType => classOf[InternalRow]
+ case _: ArrayType => classOf[ArrayType]
+ case _: MapType => classOf[MapType]
+ case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType)
+ case ObjectType(cls) => cls
+ case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object])
+ }
+
+ def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = {
+ if (arguments != Nil) {
+ arguments.map(e => dataTypeJavaClass(e.dataType))
+ } else {
+ Seq.empty
+ }
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7848f88bda1c9..6e3107f1c6f75 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
@@ -177,6 +178,7 @@ class Analyzer(
TimeWindowing ::
ResolveInlineTables(conf) ::
ResolveTimeZone(conf) ::
+ ResolvedUuidExpressions ::
TypeCoercion.typeCoercionRules(conf) ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -273,9 +275,9 @@ class Analyzer(
case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
g.copy(aggregations = assignAliases(g.aggregations))
- case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
- if child.resolved && hasUnresolvedAlias(groupByExprs) =>
- Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)
+ case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child)
+ if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) =>
+ Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child)
case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
Project(assignAliases(projectList), child)
@@ -502,9 +504,20 @@ class Analyzer(
object ResolvePivot extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved)
- | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
- case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
+ case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved)
+ || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved))
+ || !p.pivotColumn.resolved => p
+ case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
+ // Check all aggregate expressions.
+ aggregates.foreach { e =>
+ if (!isAggregateExpression(e)) {
+ throw new AnalysisException(
+ s"Aggregate expression required for pivot, found '$e'")
+ }
+ }
+ // Group-by expressions coming from SQL are implicit and need to be deduced.
+ val groupByExprs = groupByExprsOpt.getOrElse(
+ (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq)
val singleAgg = aggregates.size == 1
def outputName(value: Literal, aggregate: Expression): String = {
val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
@@ -566,16 +579,20 @@ class Analyzer(
// TODO: Don't construct the physical container until after analysis.
case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
}
- if (filteredAggregate.fastEquals(aggregate)) {
- throw new AnalysisException(
- s"Aggregate expression required for pivot, found '$aggregate'")
- }
Alias(filteredAggregate, outputName(value, aggregate))()
}
}
Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
}
}
+
+ private def isAggregateExpression(expr: Expression): Boolean = {
+ expr match {
+ case Alias(e, _) => isAggregateExpression(e)
+ case AggregateExpression(_, _, _, _) => true
+ case _ => false
+ }
+ }
}
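// Editor's note (hedged example, not part of the patch): for a SQL pivot written without an
// explicit GROUP BY, e.g. on a hypothetical table sales(year, quarter, earnings):
//   SELECT * FROM sales PIVOT (sum(earnings) FOR quarter IN (1, 2, 3, 4))
// groupByExprsOpt is None, so the grouping columns are deduced as the child's output minus the
// columns referenced by the aggregates (earnings) and the pivot column (quarter); here the
// remaining column `year` becomes the implicit group-by key.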
/**
@@ -659,13 +676,13 @@ class Analyzer(
try {
catalog.lookupRelation(tableIdentWithDb)
} catch {
- case _: NoSuchTableException =>
- u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}")
+ case e: NoSuchTableException =>
+ u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}", e)
// If the database is defined and that database is not found, throw an AnalysisException.
// Note that if the database is not defined, it is possible we are looking up a temp view.
case e: NoSuchDatabaseException =>
u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " +
- s"database ${e.db} doesn't exist.")
+ s"database ${e.db} doesn't exist.", e)
}
}
@@ -1722,15 +1739,16 @@ class Analyzer(
* 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions
* it into two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
* all regular expressions.
- * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s.
- * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts
- * it into the plan tree.
+ * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s
+ * and [[WindowFunctionType]]s.
+ * 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a
+ * [[Window]] operator and inserts it into the plan tree.
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
- private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
- projectList.exists(hasWindowFunction)
+ private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
+ exprs.exists(hasWindowFunction)
- private def hasWindowFunction(expr: NamedExpression): Boolean = {
+ private def hasWindowFunction(expr: Expression): Boolean = {
expr.find {
case window: WindowExpression => true
case _ => false
@@ -1813,6 +1831,10 @@ class Analyzer(
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)
+ case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) =>
+ failAnalysis("It is not allowed to use a window function inside an aggregate " +
+ "function. Please use the inner window function in a sub-query.")
+
// Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
// we need to extract SUM(x).
case agg: AggregateExpression if !seenWindowAggregates.contains(agg) =>
@@ -1880,7 +1902,7 @@ class Analyzer(
s"Please file a bug report with this error message, stack trace, and the query.")
} else {
val spec = distinctWindowSpec.head
- (spec.partitionSpec, spec.orderSpec)
+ (spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr))
}
}.toSeq
@@ -1888,7 +1910,7 @@ class Analyzer(
// setting this to the child of the next Window operator.
val windowOps =
groupedWindowExpressions.foldLeft(child) {
- case (last, ((partitionSpec, orderSpec), windowExpressions)) =>
+ case (last, ((partitionSpec, orderSpec, _), windowExpressions)) =>
Window(windowExpressions, partitionSpec, orderSpec, last)
}
@@ -1994,6 +2016,20 @@ class Analyzer(
}
}
+ /**
+ * Set the seed for random number generation in Uuid expressions.
+ */
+ object ResolvedUuidExpressions extends Rule[LogicalPlan] {
+ private lazy val random = new Random()
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+ case p if p.resolved => p
+ case p => p transformExpressionsUp {
+ case Uuid(None) => Uuid(Some(random.nextLong()))
+ }
+ }
+ }
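// Editor's note (sketch of the rule's effect, not part of the patch): during analysis every
// unseeded Uuid gets a fixed random seed, i.e. Uuid(None) becomes Uuid(Some(random.nextLong())),
// presumably so that the expression behaves deterministically if the plan is re-evaluated
// (for example on task retry) instead of re-seeding on every evaluation.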
+
/**
* Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
* null check. When user defines a UDF with primitive parameters, there is no way to tell if the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 90bda2a72ad82..af256b98b34f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
@@ -112,12 +113,19 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis("An offset window function can only be evaluated in an ordered " +
s"row-based window frame with a single offset: $w")
+ case _ @ WindowExpression(_: PythonUDF,
+ WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame))
+ if !frame.isUnbounded =>
+ failAnalysis("Only unbounded window frame is supported with Pandas UDFs.")
+
case w @ WindowExpression(e, s) =>
// Only allow window functions with an aggregate expression or an offset window
- // function.
+ // function or a Pandas window UDF.
e match {
case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction =>
w
+ case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) =>
+ w
case _ =>
failAnalysis(s"Expression '$e' not supported within a window function.")
}
@@ -154,7 +162,7 @@ trait CheckAnalysis extends PredicateHelper {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def isAggregateExpression(expr: Expression) = {
- expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr)
+ expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
}
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 747016beb06e7..3700c63d817ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -299,6 +299,15 @@ object FunctionRegistry {
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
expression[CountMinSketchAgg]("count_min_sketch"),
+ expression[RegrCount]("regr_count"),
+ expression[RegrSXX]("regr_sxx"),
+ expression[RegrSYY]("regr_syy"),
+ expression[RegrAvgX]("regr_avgx"),
+ expression[RegrAvgY]("regr_avgy"),
+ expression[RegrSXY]("regr_sxy"),
+ expression[RegrSlope]("regr_slope"),
+ expression[RegrR2]("regr_r2"),
+ expression[RegrIntercept]("regr_intercept"),
// string functions
expression[Ascii]("ascii"),
@@ -308,7 +317,6 @@ object FunctionRegistry {
expression[BitLength]("bit_length"),
expression[Length]("char_length"),
expression[Length]("character_length"),
- expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
expression[Decode]("decode"),
expression[Elt]("elt"),
@@ -336,7 +344,6 @@ object FunctionRegistry {
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReplace]("replace"),
- expression[StringReverse]("reverse"),
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
@@ -395,6 +402,7 @@ object FunctionRegistry {
expression[TruncTimestamp]("date_trunc"),
expression[UnixTimestamp]("unix_timestamp"),
expression[DayOfWeek]("dayofweek"),
+ expression[WeekDay]("weekday"),
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
expression[TimeWindow]("window"),
@@ -402,14 +410,39 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
+ expression[ArraysOverlap]("arrays_overlap"),
+ expression[ArrayJoin]("array_join"),
+ expression[ArrayPosition]("array_position"),
+ expression[ArraySort]("array_sort"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
+ expression[ElementAt]("element_at"),
+ expression[MapFromArrays]("map_from_arrays"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
+ expression[MapEntries]("map_entries"),
expression[Size]("size"),
+ expression[Slice]("slice"),
+ expression[Size]("cardinality"),
+ expression[ArraysZip]("arrays_zip"),
expression[SortArray]("sort_array"),
+ expression[ArrayMin]("array_min"),
+ expression[ArrayMax]("array_max"),
+ expression[Reverse]("reverse"),
+ expression[Concat]("concat"),
+ expression[Flatten]("flatten"),
+ expression[ArrayRepeat]("array_repeat"),
+ expression[ArrayRemove]("array_remove"),
CreateStruct.registryEntry,
+ // mask functions
+ expression[Mask]("mask"),
+ expression[MaskFirstN]("mask_first_n"),
+ expression[MaskLastN]("mask_last_n"),
+ expression[MaskShowFirstN]("mask_show_first_n"),
+ expression[MaskShowLastN]("mask_show_last_n"),
+ expression[MaskHash]("mask_hash"),
+
// misc functions
expression[AssertTrue]("assert_true"),
expression[Crc32]("crc32"),
@@ -526,7 +559,9 @@ object FunctionRegistry {
// Otherwise, find a constructor method that matches the number of arguments, and use that.
val params = Seq.fill(expressions.size)(classOf[Expression])
val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
- val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted
+ val validParametersCount = constructors
+ .filter(_.getParameterTypes.forall(_ == classOf[Expression]))
+ .map(_.getParameterCount).distinct.sorted
val expectedNumberOfParameters = if (validParametersCount.length == 1) {
validParametersCount.head.toString
} else {
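// Editor's note (hedged example, not part of the patch): suppose a hypothetical expression class
// declares constructors (Expression) and (Expression, Long, Boolean). When the SQL function is
// invoked with a wrong number of arguments, the old message suggested "1 or 3" valid parameter
// counts even though the three-argument constructor can never be chosen from SQL expressions;
// constructors with non-Expression parameters are now excluded, so only "1" is suggested.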
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
index f2df3e132629f..71ed75454cd4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -103,7 +103,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas
castedExpr.eval()
} catch {
case NonFatal(ex) =>
- table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
+ table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex)
}
})
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index e8669c4637d06..b2817b0538a7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -47,9 +47,9 @@ import org.apache.spark.sql.types._
object TypeCoercion {
def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] =
- InConversion ::
+ InConversion(conf) ::
WidenSetOperationTypes ::
- PromoteStrings ::
+ PromoteStrings(conf) ::
DecimalPrecision ::
BooleanEquality ::
FunctionArgumentConversion ::
@@ -59,7 +59,7 @@ object TypeCoercion {
IfCoercion ::
StackCoercion ::
Division ::
- ImplicitTypeCasts ::
+ new ImplicitTypeCasts(conf) ::
DateTimeOperations ::
WindowFrameCoercion ::
Nil
@@ -112,6 +112,14 @@ object TypeCoercion {
StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
}))
+ case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) =>
+ findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2))
+
+ case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) =>
+ val keyType = findTightestCommonType(kt1, kt2)
+ val valueType = findTightestCommonType(vt1, vt2)
+ Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2))
+
case _ => None
}
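// Editor's note (illustration only, not part of the patch): two array types that differ only in
// element nullability now widen to a single type, e.g.
//   findTightestCommonType(ArrayType(IntegerType, false), ArrayType(IntegerType, true))
//   // Some(ArrayType(IntegerType, containsNull = true))
// and maps widen key and value types the same way.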
@@ -127,7 +135,8 @@ object TypeCoercion {
* is a String and the other is not. It also handles when one op is a Date and the
* other is a Timestamp by making the target type to be String.
*/
- val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = {
+ private def findCommonTypeForBinaryComparison(
+ dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match {
// We should cast all relative timestamp/date/string comparison into string comparisons
// This behaves as a user would expect because timestamp strings sort lexicographically.
// i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
@@ -135,11 +144,17 @@ object TypeCoercion {
case (DateType, StringType) => Some(StringType)
case (StringType, TimestampType) => Some(StringType)
case (TimestampType, StringType) => Some(StringType)
- case (TimestampType, DateType) => Some(StringType)
- case (DateType, TimestampType) => Some(StringType)
case (StringType, NullType) => Some(StringType)
case (NullType, StringType) => Some(StringType)
+ // Cast to TimestampType when we compare DateType with TimestampType
+ // if conf.compareDateTimestampInTimestamp is true
+ // i.e. TimeStamp('2017-03-01 00:00:00') eq Date('2017-03-01') = true
+ case (TimestampType, DateType)
+ => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType)
+ case (DateType, TimestampType)
+ => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType)
+
// There is no proper decimal type we can pick,
// using double type is the best we can do.
// See SPARK-22469 for details.
@@ -147,7 +162,7 @@ object TypeCoercion {
case (s: StringType, n: DecimalType) => Some(DoubleType)
case (l: StringType, r: AtomicType) if r != StringType => Some(r)
- case (l: AtomicType, r: StringType) if (l != StringType) => Some(l)
+ case (l: AtomicType, r: StringType) if l != StringType => Some(l)
case (l, r) => None
}
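// Editor's note (hedged illustration, not part of the patch): with compareDateTimestampInTimestamp
// enabled, a Date/Timestamp comparison is performed in TimestampType rather than via the old
// string comparison, so e.g.
//   SELECT CAST('2017-03-01 00:00:00' AS TIMESTAMP) = CAST('2017-03-01' AS DATE)
// evaluates to true, whereas comparing the strings "2017-03-01 00:00:00" and "2017-03-01"
// would yield false.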
@@ -168,11 +183,27 @@ object TypeCoercion {
})
}
+ /**
+ * Whether the data type contains StringType.
+ */
+ def hasStringType(dt: DataType): Boolean = dt match {
+ case StringType => true
+ case ArrayType(et, _) => hasStringType(et)
+ // Add StructType if we support string promotion for struct fields in the future.
+ case _ => false
+ }
+
private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = {
- types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
- case Some(d) => findWiderTypeForTwo(d, c)
- case None => None
- })
+ // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal
+ // a op (b op c). This is only a problem for StringType or nested StringType in ArrayType.
+ // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance,
+ // (TimestampType, IntegerType, StringType) should have StringType as the wider common type.
+ val (stringTypes, nonStringTypes) = types.partition(hasStringType(_))
+ (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) =>
+ r match {
+ case Some(d) => findWiderTypeForTwo(d, c)
+ case _ => None
+ })
}
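// Editor's note (worked example, not part of the patch): for Seq(TimestampType, IntegerType,
// StringType), a plain left fold would first try to widen (TimestampType, IntegerType), which has
// no wider type, and end up with None; folding the StringType first lets string promotion turn
// each subsequent pairing into StringType, which is the intended result.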
/**
@@ -313,7 +344,7 @@ object TypeCoercion {
/**
* Promotes strings that appear in arithmetic expressions.
*/
- object PromoteStrings extends TypeCoercionRule {
+ case class PromoteStrings(conf: SQLConf) extends TypeCoercionRule {
private def castExpr(expr: Expression, targetType: DataType): Expression = {
(expr.dataType, targetType) match {
case (NullType, dt) => Literal.create(null, targetType)
@@ -342,8 +373,8 @@ object TypeCoercion {
p.makeCopy(Array(left, Cast(right, TimestampType)))
case p @ BinaryComparison(left, right)
- if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
- val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
+ if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined =>
+ val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
case Abs(e @ StringType()) => Abs(Cast(e, DoubleType))
@@ -374,7 +405,7 @@ object TypeCoercion {
* operator type is found the original expression will be returned and an
* Analysis Exception will be raised at the type checking phase.
*/
- object InConversion extends TypeCoercionRule {
+ case class InConversion(conf: SQLConf) extends TypeCoercionRule {
private def flattenExpr(expr: Expression): Seq[Expression] = {
expr match {
// Multi columns in IN clause is represented as a CreateNamedStruct.
@@ -400,7 +431,7 @@ object TypeCoercion {
val rhs = sub.output
val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
- findCommonTypeForBinaryComparison(l.dataType, r.dataType)
+ findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf)
.orElse(findTightestCommonType(l.dataType, r.dataType))
}
@@ -497,6 +528,14 @@ object TypeCoercion {
case None => a
}
+ case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
+ !haveSameType(children) =>
+ val types = children.map(_.dataType)
+ findWiderCommonType(types) match {
+ case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
+ case None => c
+ }
+
case m @ CreateMap(children) if m.keys.length == m.values.length &&
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
val newKeys = if (haveSameType(m.keys)) {
@@ -737,12 +776,33 @@ object TypeCoercion {
/**
* Casts types according to the expected input types for [[Expression]]s.
*/
- object ImplicitTypeCasts extends TypeCoercionRule {
+ class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule {
+
+ private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING)
+
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
+ // Special rules for `from/to_utc_timestamp`. These 2 functions assume the input timestamp
+ // string is in a specific timezone, so the string itself should not contain timezone.
+ // TODO: We should move the type coercion logic to expressions instead of a central
+ // place to put all the rules.
+ case e: FromUTCTimestamp if e.left.dataType == StringType =>
+ if (rejectTzInString) {
+ e.copy(left = StringToTimestampWithoutTimezone(e.left))
+ } else {
+ e.copy(left = Cast(e.left, TimestampType))
+ }
+
+ case e: ToUTCTimestamp if e.left.dataType == StringType =>
+ if (rejectTzInString) {
+ e.copy(left = StringToTimestampWithoutTimezone(e.left))
+ } else {
+ e.copy(left = Cast(e.left, TimestampType))
+ }
+
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
findTightestCommonType(left.dataType, right.dataType).map { commonType =>
if (b.inputType.acceptsType(commonType)) {
@@ -759,7 +819,7 @@ object TypeCoercion {
case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
// If we cannot do the implicit cast, just use the original input.
- implicitCast(in, expected).getOrElse(in)
+ ImplicitTypeCasts.implicitCast(in, expected).getOrElse(in)
}
e.withNewChildren(children)
@@ -775,6 +835,9 @@ object TypeCoercion {
}
e.withNewChildren(children)
}
+ }
+
+ object ImplicitTypeCasts {
/**
* Given an expected data type, try to cast the expression and return the cast expression.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index b55043c270644..2bed41672fe33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
/**
@@ -345,7 +346,8 @@ object UnsupportedOperationChecker {
plan.foreachUp { implicit subPlan =>
subPlan match {
case (_: Project | _: Filter | _: MapElements | _: MapPartitions |
- _: DeserializeToObject | _: SerializeFromObject) =>
+ _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias |
+ _: TypedFilter) =>
case node if node.nodeName == "StreamingRelationV2" =>
case node =>
throwError(s"Continuous processing does not support ${node.nodeName} operations.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index 7731336d247db..354a3fa0602a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -41,6 +41,11 @@ package object analysis {
def failAnalysis(msg: String): Nothing = {
throw new AnalysisException(msg, t.origin.line, t.origin.startPosition)
}
+
+ /** Fails the analysis at the point where a specific tree node was parsed. */
+ def failAnalysis(msg: String, cause: Throwable): Nothing = {
+ throw new AnalysisException(msg, t.origin.line, t.origin.startPosition, cause = Some(cause))
+ }
}
/** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index a65f58fa61ff4..71e23175168e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.parser.ParserUtils
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -335,7 +335,7 @@ case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSens
* @param names the names to be associated with each output of computing [[child]].
*/
case class MultiAlias(child: Expression, names: Seq[String])
- extends UnaryExpression with NamedExpression with CodegenFallback {
+ extends UnaryExpression with NamedExpression with Unevaluable {
override def name: String = throw new UnresolvedException(this, "name")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
index 45b4f013620c1..1a145c24d78cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
@@ -17,10 +17,9 @@
package org.apache.spark.sql.catalyst.catalog
-import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException}
+import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchPartitionException, NoSuchTableException}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.ListenerBus
/**
* Interface for the system catalog (of functions, partitions, tables, and databases).
@@ -31,10 +30,13 @@ import org.apache.spark.util.ListenerBus
*
* Implementations should throw [[NoSuchDatabaseException]] when databases don't exist.
*/
-abstract class ExternalCatalog
- extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] {
+trait ExternalCatalog {
import CatalogTypes.TablePartitionSpec
+ // --------------------------------------------------------------------------
+ // Utils
+ // --------------------------------------------------------------------------
+
protected def requireDbExists(db: String): Unit = {
if (!databaseExists(db)) {
throw new NoSuchDatabaseException(db)
@@ -63,22 +65,9 @@ abstract class ExternalCatalog
// Databases
// --------------------------------------------------------------------------
- final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {
- val db = dbDefinition.name
- postToAll(CreateDatabasePreEvent(db))
- doCreateDatabase(dbDefinition, ignoreIfExists)
- postToAll(CreateDatabaseEvent(db))
- }
+ def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit
- protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit
-
- final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = {
- postToAll(DropDatabasePreEvent(db))
- doDropDatabase(db, ignoreIfNotExists, cascade)
- postToAll(DropDatabaseEvent(db))
- }
-
- protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit
+ def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit
/**
* Alter a database whose name matches the one specified in `dbDefinition`,
@@ -87,14 +76,7 @@ abstract class ExternalCatalog
* Note: If the underlying implementation does not support altering a certain field,
* this becomes a no-op.
*/
- final def alterDatabase(dbDefinition: CatalogDatabase): Unit = {
- val db = dbDefinition.name
- postToAll(AlterDatabasePreEvent(db))
- doAlterDatabase(dbDefinition)
- postToAll(AlterDatabaseEvent(db))
- }
-
- protected def doAlterDatabase(dbDefinition: CatalogDatabase): Unit
+ def alterDatabase(dbDefinition: CatalogDatabase): Unit
def getDatabase(db: String): CatalogDatabase
@@ -110,41 +92,15 @@ abstract class ExternalCatalog
// Tables
// --------------------------------------------------------------------------
- final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
- val db = tableDefinition.database
- val name = tableDefinition.identifier.table
- val tableDefinitionWithVersion =
- tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION)
- postToAll(CreateTablePreEvent(db, name))
- doCreateTable(tableDefinitionWithVersion, ignoreIfExists)
- postToAll(CreateTableEvent(db, name))
- }
-
- protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit
-
- final def dropTable(
- db: String,
- table: String,
- ignoreIfNotExists: Boolean,
- purge: Boolean): Unit = {
- postToAll(DropTablePreEvent(db, table))
- doDropTable(db, table, ignoreIfNotExists, purge)
- postToAll(DropTableEvent(db, table))
- }
+ def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit
- protected def doDropTable(
+ def dropTable(
db: String,
table: String,
ignoreIfNotExists: Boolean,
purge: Boolean): Unit
- final def renameTable(db: String, oldName: String, newName: String): Unit = {
- postToAll(RenameTablePreEvent(db, oldName, newName))
- doRenameTable(db, oldName, newName)
- postToAll(RenameTableEvent(db, oldName, newName))
- }
-
- protected def doRenameTable(db: String, oldName: String, newName: String): Unit
+ def renameTable(db: String, oldName: String, newName: String): Unit
/**
* Alter a table whose database and name match the ones specified in `tableDefinition`, assuming
@@ -154,15 +110,7 @@ abstract class ExternalCatalog
* Note: If the underlying implementation does not support altering a certain field,
* this becomes a no-op.
*/
- final def alterTable(tableDefinition: CatalogTable): Unit = {
- val db = tableDefinition.database
- val name = tableDefinition.identifier.table
- postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE))
- doAlterTable(tableDefinition)
- postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE))
- }
-
- protected def doAlterTable(tableDefinition: CatalogTable): Unit
+ def alterTable(tableDefinition: CatalogTable): Unit
/**
* Alter the data schema of a table identified by the provided database and table name. The new
@@ -173,22 +121,10 @@ abstract class ExternalCatalog
* @param table Name of table to alter schema for
* @param newDataSchema Updated data schema to be used for the table.
*/
- final def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = {
- postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA))
- doAlterTableDataSchema(db, table, newDataSchema)
- postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA))
- }
-
- protected def doAlterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit
+ def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit
/** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */
- final def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit = {
- postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS))
- doAlterTableStats(db, table, stats)
- postToAll(AlterTableEvent(db, table, AlterTableKind.STATS))
- }
-
- protected def doAlterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit
+ def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit
def getTable(db: String, table: String): CatalogTable
@@ -340,49 +276,17 @@ abstract class ExternalCatalog
// Functions
// --------------------------------------------------------------------------
- final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = {
- val name = funcDefinition.identifier.funcName
- postToAll(CreateFunctionPreEvent(db, name))
- doCreateFunction(db, funcDefinition)
- postToAll(CreateFunctionEvent(db, name))
- }
+ def createFunction(db: String, funcDefinition: CatalogFunction): Unit
- protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit
+ def dropFunction(db: String, funcName: String): Unit
- final def dropFunction(db: String, funcName: String): Unit = {
- postToAll(DropFunctionPreEvent(db, funcName))
- doDropFunction(db, funcName)
- postToAll(DropFunctionEvent(db, funcName))
- }
+ def alterFunction(db: String, funcDefinition: CatalogFunction): Unit
- protected def doDropFunction(db: String, funcName: String): Unit
-
- final def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = {
- val name = funcDefinition.identifier.funcName
- postToAll(AlterFunctionPreEvent(db, name))
- doAlterFunction(db, funcDefinition)
- postToAll(AlterFunctionEvent(db, name))
- }
-
- protected def doAlterFunction(db: String, funcDefinition: CatalogFunction): Unit
-
- final def renameFunction(db: String, oldName: String, newName: String): Unit = {
- postToAll(RenameFunctionPreEvent(db, oldName, newName))
- doRenameFunction(db, oldName, newName)
- postToAll(RenameFunctionEvent(db, oldName, newName))
- }
-
- protected def doRenameFunction(db: String, oldName: String, newName: String): Unit
+ def renameFunction(db: String, oldName: String, newName: String): Unit
def getFunction(db: String, funcName: String): CatalogFunction
def functionExists(db: String, funcName: String): Boolean
def listFunctions(db: String, pattern: String): Seq[String]
-
- override protected def doPostEvent(
- listener: ExternalCatalogEventListener,
- event: ExternalCatalogEvent): Unit = {
- listener.onEvent(event)
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala
new file mode 100644
index 0000000000000..2f009be5816fa
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ListenerBus
+
+/**
+ * Wraps an ExternalCatalog and posts pre/post [[ExternalCatalogEvent]]s to registered listeners
+ * around catalog mutations.
+ */
+class ExternalCatalogWithListener(delegate: ExternalCatalog)
+ extends ExternalCatalog
+ with ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] {
+ import CatalogTypes.TablePartitionSpec
+
+ def unwrapped: ExternalCatalog = delegate
+
+ override protected def doPostEvent(
+ listener: ExternalCatalogEventListener,
+ event: ExternalCatalogEvent): Unit = {
+ listener.onEvent(event)
+ }
+
+ // --------------------------------------------------------------------------
+ // Databases
+ // --------------------------------------------------------------------------
+
+ override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {
+ val db = dbDefinition.name
+ postToAll(CreateDatabasePreEvent(db))
+ delegate.createDatabase(dbDefinition, ignoreIfExists)
+ postToAll(CreateDatabaseEvent(db))
+ }
+
+ override def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = {
+ postToAll(DropDatabasePreEvent(db))
+ delegate.dropDatabase(db, ignoreIfNotExists, cascade)
+ postToAll(DropDatabaseEvent(db))
+ }
+
+ override def alterDatabase(dbDefinition: CatalogDatabase): Unit = {
+ val db = dbDefinition.name
+ postToAll(AlterDatabasePreEvent(db))
+ delegate.alterDatabase(dbDefinition)
+ postToAll(AlterDatabaseEvent(db))
+ }
+
+ override def getDatabase(db: String): CatalogDatabase = {
+ delegate.getDatabase(db)
+ }
+
+ override def databaseExists(db: String): Boolean = {
+ delegate.databaseExists(db)
+ }
+
+ override def listDatabases(): Seq[String] = {
+ delegate.listDatabases()
+ }
+
+ override def listDatabases(pattern: String): Seq[String] = {
+ delegate.listDatabases(pattern)
+ }
+
+ override def setCurrentDatabase(db: String): Unit = {
+ delegate.setCurrentDatabase(db)
+ }
+
+ // --------------------------------------------------------------------------
+ // Tables
+ // --------------------------------------------------------------------------
+
+ override def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
+ val db = tableDefinition.database
+ val name = tableDefinition.identifier.table
+ val tableDefinitionWithVersion =
+ tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION)
+ postToAll(CreateTablePreEvent(db, name))
+ delegate.createTable(tableDefinitionWithVersion, ignoreIfExists)
+ postToAll(CreateTableEvent(db, name))
+ }
+
+ override def dropTable(
+ db: String,
+ table: String,
+ ignoreIfNotExists: Boolean,
+ purge: Boolean): Unit = {
+ postToAll(DropTablePreEvent(db, table))
+ delegate.dropTable(db, table, ignoreIfNotExists, purge)
+ postToAll(DropTableEvent(db, table))
+ }
+
+ override def renameTable(db: String, oldName: String, newName: String): Unit = {
+ postToAll(RenameTablePreEvent(db, oldName, newName))
+ delegate.renameTable(db, oldName, newName)
+ postToAll(RenameTableEvent(db, oldName, newName))
+ }
+
+ override def alterTable(tableDefinition: CatalogTable): Unit = {
+ val db = tableDefinition.database
+ val name = tableDefinition.identifier.table
+ postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE))
+ delegate.alterTable(tableDefinition)
+ postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE))
+ }
+
+ override def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = {
+ postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA))
+ delegate.alterTableDataSchema(db, table, newDataSchema)
+ postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA))
+ }
+
+ override def alterTableStats(
+ db: String,
+ table: String,
+ stats: Option[CatalogStatistics]): Unit = {
+ postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS))
+ delegate.alterTableStats(db, table, stats)
+ postToAll(AlterTableEvent(db, table, AlterTableKind.STATS))
+ }
+
+ override def getTable(db: String, table: String): CatalogTable = {
+ delegate.getTable(db, table)
+ }
+
+ override def tableExists(db: String, table: String): Boolean = {
+ delegate.tableExists(db, table)
+ }
+
+ override def listTables(db: String): Seq[String] = {
+ delegate.listTables(db)
+ }
+
+ override def listTables(db: String, pattern: String): Seq[String] = {
+ delegate.listTables(db, pattern)
+ }
+
+ override def loadTable(
+ db: String,
+ table: String,
+ loadPath: String,
+ isOverwrite: Boolean,
+ isSrcLocal: Boolean): Unit = {
+ delegate.loadTable(db, table, loadPath, isOverwrite, isSrcLocal)
+ }
+
+ override def loadPartition(
+ db: String,
+ table: String,
+ loadPath: String,
+ partition: TablePartitionSpec,
+ isOverwrite: Boolean,
+ inheritTableSpecs: Boolean,
+ isSrcLocal: Boolean): Unit = {
+ delegate.loadPartition(
+ db, table, loadPath, partition, isOverwrite, inheritTableSpecs, isSrcLocal)
+ }
+
+ override def loadDynamicPartitions(
+ db: String,
+ table: String,
+ loadPath: String,
+ partition: TablePartitionSpec,
+ replace: Boolean,
+ numDP: Int): Unit = {
+ delegate.loadDynamicPartitions(db, table, loadPath, partition, replace, numDP)
+ }
+
+ // --------------------------------------------------------------------------
+ // Partitions
+ // --------------------------------------------------------------------------
+
+ override def createPartitions(
+ db: String,
+ table: String,
+ parts: Seq[CatalogTablePartition],
+ ignoreIfExists: Boolean): Unit = {
+ delegate.createPartitions(db, table, parts, ignoreIfExists)
+ }
+
+ override def dropPartitions(
+ db: String,
+ table: String,
+ partSpecs: Seq[TablePartitionSpec],
+ ignoreIfNotExists: Boolean,
+ purge: Boolean,
+ retainData: Boolean): Unit = {
+ delegate.dropPartitions(db, table, partSpecs, ignoreIfNotExists, purge, retainData)
+ }
+
+ override def renamePartitions(
+ db: String,
+ table: String,
+ specs: Seq[TablePartitionSpec],
+ newSpecs: Seq[TablePartitionSpec]): Unit = {
+ delegate.renamePartitions(db, table, specs, newSpecs)
+ }
+
+ override def alterPartitions(
+ db: String,
+ table: String,
+ parts: Seq[CatalogTablePartition]): Unit = {
+ delegate.alterPartitions(db, table, parts)
+ }
+
+ override def getPartition(
+ db: String,
+ table: String,
+ spec: TablePartitionSpec): CatalogTablePartition = {
+ delegate.getPartition(db, table, spec)
+ }
+
+ override def getPartitionOption(
+ db: String,
+ table: String,
+ spec: TablePartitionSpec): Option[CatalogTablePartition] = {
+ delegate.getPartitionOption(db, table, spec)
+ }
+
+ override def listPartitionNames(
+ db: String,
+ table: String,
+ partialSpec: Option[TablePartitionSpec] = None): Seq[String] = {
+ delegate.listPartitionNames(db, table, partialSpec)
+ }
+
+ override def listPartitions(
+ db: String,
+ table: String,
+ partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = {
+ delegate.listPartitions(db, table, partialSpec)
+ }
+
+ override def listPartitionsByFilter(
+ db: String,
+ table: String,
+ predicates: Seq[Expression],
+ defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
+ delegate.listPartitionsByFilter(db, table, predicates, defaultTimeZoneId)
+ }
+
+ // --------------------------------------------------------------------------
+ // Functions
+ // --------------------------------------------------------------------------
+
+ override def createFunction(db: String, funcDefinition: CatalogFunction): Unit = {
+ val name = funcDefinition.identifier.funcName
+ postToAll(CreateFunctionPreEvent(db, name))
+ delegate.createFunction(db, funcDefinition)
+ postToAll(CreateFunctionEvent(db, name))
+ }
+
+ override def dropFunction(db: String, funcName: String): Unit = {
+ postToAll(DropFunctionPreEvent(db, funcName))
+ delegate.dropFunction(db, funcName)
+ postToAll(DropFunctionEvent(db, funcName))
+ }
+
+ override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = {
+ val name = funcDefinition.identifier.funcName
+ postToAll(AlterFunctionPreEvent(db, name))
+ delegate.alterFunction(db, funcDefinition)
+ postToAll(AlterFunctionEvent(db, name))
+ }
+
+ override def renameFunction(db: String, oldName: String, newName: String): Unit = {
+ postToAll(RenameFunctionPreEvent(db, oldName, newName))
+ delegate.renameFunction(db, oldName, newName)
+ postToAll(RenameFunctionEvent(db, oldName, newName))
+ }
+
+ override def getFunction(db: String, funcName: String): CatalogFunction = {
+ delegate.getFunction(db, funcName)
+ }
+
+ override def functionExists(db: String, funcName: String): Boolean = {
+ delegate.functionExists(db, funcName)
+ }
+
+ override def listFunctions(db: String, pattern: String): Seq[String] = {
+ delegate.listFunctions(db, pattern)
+ }
+}
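// Editor's note (hedged usage sketch, not part of the patch): callers that need catalog events
// wrap the concrete catalog; the wrapper delegates the work and posts pre/post events around each
// mutation. `someDbDefinition` is a placeholder CatalogDatabase.
//   val catalog = new ExternalCatalogWithListener(new InMemoryCatalog())
//   catalog.addListener(new ExternalCatalogEventListener {
//     override def onEvent(event: ExternalCatalogEvent): Unit = println(event)
//   })
//   catalog.createDatabase(someDbDefinition, ignoreIfExists = false)
//   // prints CreateDatabasePreEvent(db) followed by CreateDatabaseEvent(db)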
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index 8eacfa058bd52..741dc46b07382 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -98,7 +98,7 @@ class InMemoryCatalog(
// Databases
// --------------------------------------------------------------------------
- override protected def doCreateDatabase(
+ override def createDatabase(
dbDefinition: CatalogDatabase,
ignoreIfExists: Boolean): Unit = synchronized {
if (catalog.contains(dbDefinition.name)) {
@@ -119,7 +119,7 @@ class InMemoryCatalog(
}
}
- override protected def doDropDatabase(
+ override def dropDatabase(
db: String,
ignoreIfNotExists: Boolean,
cascade: Boolean): Unit = synchronized {
@@ -152,7 +152,7 @@ class InMemoryCatalog(
}
}
- override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized {
+ override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized {
requireDbExists(dbDefinition.name)
catalog(dbDefinition.name).db = dbDefinition
}
@@ -180,7 +180,7 @@ class InMemoryCatalog(
// Tables
// --------------------------------------------------------------------------
- override protected def doCreateTable(
+ override def createTable(
tableDefinition: CatalogTable,
ignoreIfExists: Boolean): Unit = synchronized {
assert(tableDefinition.identifier.database.isDefined)
@@ -221,7 +221,7 @@ class InMemoryCatalog(
}
}
- override protected def doDropTable(
+ override def dropTable(
db: String,
table: String,
ignoreIfNotExists: Boolean,
@@ -264,7 +264,7 @@ class InMemoryCatalog(
}
}
- override protected def doRenameTable(
+ override def renameTable(
db: String,
oldName: String,
newName: String): Unit = synchronized {
@@ -294,7 +294,7 @@ class InMemoryCatalog(
catalog(db).tables.remove(oldName)
}
- override def doAlterTable(tableDefinition: CatalogTable): Unit = synchronized {
+ override def alterTable(tableDefinition: CatalogTable): Unit = synchronized {
assert(tableDefinition.identifier.database.isDefined)
val db = tableDefinition.identifier.database.get
requireTableExists(db, tableDefinition.identifier.table)
@@ -303,7 +303,7 @@ class InMemoryCatalog(
catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition
}
- override def doAlterTableDataSchema(
+ override def alterTableDataSchema(
db: String,
table: String,
newDataSchema: StructType): Unit = synchronized {
@@ -313,7 +313,7 @@ class InMemoryCatalog(
catalog(db).tables(table).table = origTable.copy(schema = newSchema)
}
- override def doAlterTableStats(
+ override def alterTableStats(
db: String,
table: String,
stats: Option[CatalogStatistics]): Unit = synchronized {
@@ -564,24 +564,24 @@ class InMemoryCatalog(
// Functions
// --------------------------------------------------------------------------
- override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized {
+ override def createFunction(db: String, func: CatalogFunction): Unit = synchronized {
requireDbExists(db)
requireFunctionNotExists(db, func.identifier.funcName)
catalog(db).functions.put(func.identifier.funcName, func)
}
- override protected def doDropFunction(db: String, funcName: String): Unit = synchronized {
+ override def dropFunction(db: String, funcName: String): Unit = synchronized {
requireFunctionExists(db, funcName)
catalog(db).functions.remove(funcName)
}
- override protected def doAlterFunction(db: String, func: CatalogFunction): Unit = synchronized {
+ override def alterFunction(db: String, func: CatalogFunction): Unit = synchronized {
requireDbExists(db)
requireFunctionExists(db, func.identifier.funcName)
catalog(db).functions.put(func.identifier.funcName, func)
}
- override protected def doRenameFunction(
+ override def renameFunction(
db: String,
oldName: String,
newName: String): Unit = synchronized {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 4b119c75260a7..c390337c03ff5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -54,8 +54,8 @@ object SessionCatalog {
* This class must be thread-safe.
*/
class SessionCatalog(
- val externalCatalog: ExternalCatalog,
- globalTempViewManager: GlobalTempViewManager,
+ externalCatalogBuilder: () => ExternalCatalog,
+ globalTempViewManagerBuilder: () => GlobalTempViewManager,
functionRegistry: FunctionRegistry,
conf: SQLConf,
hadoopConf: Configuration,
@@ -70,8 +70,8 @@ class SessionCatalog(
functionRegistry: FunctionRegistry,
conf: SQLConf) {
this(
- externalCatalog,
- new GlobalTempViewManager("global_temp"),
+ () => externalCatalog,
+ () => new GlobalTempViewManager("global_temp"),
functionRegistry,
conf,
new Configuration(),
@@ -87,6 +87,9 @@ class SessionCatalog(
new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true))
}
+ lazy val externalCatalog = externalCatalogBuilder()
+ lazy val globalTempViewManager = globalTempViewManagerBuilder()
+
/** List of temporary views, mapping from table name to their logical plan. */
@GuardedBy("this")
protected val tempViews = new mutable.HashMap[String, LogicalPlan]
@@ -283,9 +286,13 @@ class SessionCatalog(
* Create a metastore table in the database specified in `tableDefinition`.
* If no such database is specified, create it in the current database.
*/
- def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
+ def createTable(
+ tableDefinition: CatalogTable,
+ ignoreIfExists: Boolean,
+ validateLocation: Boolean = true): Unit = {
val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase))
val table = formatTableName(tableDefinition.identifier.table)
+ val tableIdentifier = TableIdentifier(table, Some(db))
validateName(table)
val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined
@@ -295,15 +302,37 @@ class SessionCatalog(
makeQualifiedPath(tableDefinition.storage.locationUri.get)
tableDefinition.copy(
storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)),
- identifier = TableIdentifier(table, Some(db)))
+ identifier = tableIdentifier)
} else {
- tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
+ tableDefinition.copy(identifier = tableIdentifier)
}
requireDbExists(db)
+ if (tableExists(newTableDefinition.identifier)) {
+ if (!ignoreIfExists) {
+ throw new TableAlreadyExistsException(db = db, table = table)
+ }
+ } else if (validateLocation) {
+ validateTableLocation(newTableDefinition)
+ }
externalCatalog.createTable(newTableDefinition, ignoreIfExists)
}
+ def validateTableLocation(table: CatalogTable): Unit = {
+ // SPARK-19724: the default location of a managed table should be non-existent or empty.
+ if (table.tableType == CatalogTableType.MANAGED &&
+ !conf.allowCreatingManagedTableUsingNonemptyLocation) {
+ val tableLocation =
+ new Path(table.storage.locationUri.getOrElse(defaultTablePath(table.identifier)))
+ val fs = tableLocation.getFileSystem(hadoopConf)
+
+ if (fs.exists(tableLocation) && fs.listStatus(tableLocation).nonEmpty) {
+ throw new AnalysisException(s"Can not create the managed table('${table.identifier}')" +
+ s". The associated location('${tableLocation.toString}') already exists.")
+ }
+ }
+ }
+
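// Editor's note (illustration of the new check, not part of the patch): with the default
// configuration, creating a managed table whose location already exists and is non-empty now
// fails with an AnalysisException instead of silently adopting the existing files, e.g.
//   CREATE TABLE t (i INT) USING parquet   -- fails if <warehouse>/t/ already contains files
// unless allowCreatingManagedTableUsingNonemptyLocation is enabled.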
/**
* Alter the metadata of an existing metastore table identified by `tableDefinition`.
*
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 95b6fbb0cd61a..f3e67dc4e975c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -21,7 +21,9 @@ import java.net.URI
import java.util.Date
import scala.collection.mutable
+import scala.util.control.NonFatal
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
@@ -30,7 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
/**
@@ -361,7 +363,7 @@ object CatalogTable {
case class CatalogStatistics(
sizeInBytes: BigInt,
rowCount: Option[BigInt] = None,
- colStats: Map[String, ColumnStat] = Map.empty) {
+ colStats: Map[String, CatalogColumnStat] = Map.empty) {
/**
* Convert [[CatalogStatistics]] to [[Statistics]], and match column stats to attributes based
@@ -369,7 +371,8 @@ case class CatalogStatistics(
*/
def toPlanStats(planOutput: Seq[Attribute], cboEnabled: Boolean): Statistics = {
if (cboEnabled && rowCount.isDefined) {
- val attrStats = AttributeMap(planOutput.flatMap(a => colStats.get(a.name).map(a -> _)))
+ val attrStats = AttributeMap(planOutput
+ .flatMap(a => colStats.get(a.name).map(a -> _.toPlanStat(a.name, a.dataType))))
// Estimate size as number of rows * row size.
val size = EstimationUtils.getOutputSize(planOutput, rowCount.get, attrStats)
Statistics(sizeInBytes = size, rowCount = rowCount, attributeStats = attrStats)
@@ -387,6 +390,143 @@ case class CatalogStatistics(
}
}
+/**
+ * Statistics for a column, in the form used by [[CatalogTable]] to interact with the metastore.
+ */
+case class CatalogColumnStat(
+ distinctCount: Option[BigInt] = None,
+ min: Option[String] = None,
+ max: Option[String] = None,
+ nullCount: Option[BigInt] = None,
+ avgLen: Option[Long] = None,
+ maxLen: Option[Long] = None,
+ histogram: Option[Histogram] = None) {
+
+ /**
+ * Returns a map from string to string that can be used to serialize the column stats.
+ * The key is the column name combined with the field name (e.g. "colName.distinctCount"),
+ * and the value is the string representation for the value.
+ * min/max values are stored as Strings. They can be deserialized using
+ * [[CatalogColumnStat.fromExternalString]].
+ *
+ * As part of the protocol, the returned map always contains a key called "version".
+ * Any of the fields that are null (None) won't appear in the map.
+ */
+ def toMap(colName: String): Map[String, String] = {
+ val map = new scala.collection.mutable.HashMap[String, String]
+ map.put(s"${colName}.${CatalogColumnStat.KEY_VERSION}", "1")
+ distinctCount.foreach { v =>
+ map.put(s"${colName}.${CatalogColumnStat.KEY_DISTINCT_COUNT}", v.toString)
+ }
+ nullCount.foreach { v =>
+ map.put(s"${colName}.${CatalogColumnStat.KEY_NULL_COUNT}", v.toString)
+ }
+ avgLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_AVG_LEN}", v.toString) }
+ maxLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_LEN}", v.toString) }
+ min.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MIN_VALUE}", v) }
+ max.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_VALUE}", v) }
+ histogram.foreach { h =>
+ map.put(s"${colName}.${CatalogColumnStat.KEY_HISTOGRAM}", HistogramSerializer.serialize(h))
+ }
+ map.toMap
+ }
+
+ /** Convert [[CatalogColumnStat]] to [[ColumnStat]]. */
+ def toPlanStat(
+ colName: String,
+ dataType: DataType): ColumnStat =
+ ColumnStat(
+ distinctCount = distinctCount,
+ min = min.map(CatalogColumnStat.fromExternalString(_, colName, dataType)),
+ max = max.map(CatalogColumnStat.fromExternalString(_, colName, dataType)),
+ nullCount = nullCount,
+ avgLen = avgLen,
+ maxLen = maxLen,
+ histogram = histogram)
+}
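// Editor's note (illustration only, not part of the patch): the serialized form produced by toMap
// for a hypothetical integer column "age" with a few stats populated:
//   CatalogColumnStat(distinctCount = Some(10), min = Some("0"), max = Some("95"),
//     nullCount = Some(0)).toMap("age")
//   // Map("age.version" -> "1", "age.distinctCount" -> "10", "age.nullCount" -> "0",
//   //     "age.min" -> "0", "age.max" -> "95")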
+
+object CatalogColumnStat extends Logging {
+
+ // List of string keys used to serialize CatalogColumnStat
+ val KEY_VERSION = "version"
+ private val KEY_DISTINCT_COUNT = "distinctCount"
+ private val KEY_MIN_VALUE = "min"
+ private val KEY_MAX_VALUE = "max"
+ private val KEY_NULL_COUNT = "nullCount"
+ private val KEY_AVG_LEN = "avgLen"
+ private val KEY_MAX_LEN = "maxLen"
+ private val KEY_HISTOGRAM = "histogram"
+
+ /**
+ * Converts the string representation of a value to the corresponding Catalyst value, given the
+ * column's data type.
+ */
+ def fromExternalString(s: String, name: String, dataType: DataType): Any = {
+ dataType match {
+ case BooleanType => s.toBoolean
+ case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s))
+ case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s))
+ case ByteType => s.toByte
+ case ShortType => s.toShort
+ case IntegerType => s.toInt
+ case LongType => s.toLong
+ case FloatType => s.toFloat
+ case DoubleType => s.toDouble
+ case _: DecimalType => Decimal(s)
+ // This version of Spark does not use min/max for binary/string types so we ignore it.
+ case BinaryType | StringType => null
+ case _ =>
+ throw new AnalysisException("Column statistics deserialization is not supported for " +
+ s"column $name of data type: $dataType.")
+ }
+ }
+
+ /**
+ * Converts the given Catalyst value to the string representation of the corresponding
+ * external data type.
+ */
+ def toExternalString(v: Any, colName: String, dataType: DataType): String = {
+ val externalValue = dataType match {
+ case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int])
+ case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long])
+ case BooleanType | _: IntegralType | FloatType | DoubleType => v
+ case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal
+ // This version of Spark does not use min/max for binary/string types so we ignore it.
+ case _ =>
+ throw new AnalysisException("Column statistics serialization is not supported for " +
+ s"column $colName of data type: $dataType.")
+ }
+ externalValue.toString
+ }
+
+
+ /**
+ * Creates a [[CatalogColumnStat]] object from the given map.
+ * This is used to deserialize column stats from some external storage.
+ * The serialization side is defined in [[CatalogColumnStat.toMap]].
+ */
+ def fromMap(
+ table: String,
+ colName: String,
+ map: Map[String, String]): Option[CatalogColumnStat] = {
+
+ try {
+ Some(CatalogColumnStat(
+ distinctCount = map.get(s"${colName}.${KEY_DISTINCT_COUNT}").map(v => BigInt(v.toLong)),
+ min = map.get(s"${colName}.${KEY_MIN_VALUE}"),
+ max = map.get(s"${colName}.${KEY_MAX_VALUE}"),
+ nullCount = map.get(s"${colName}.${KEY_NULL_COUNT}").map(v => BigInt(v.toLong)),
+ avgLen = map.get(s"${colName}.${KEY_AVG_LEN}").map(_.toLong),
+ maxLen = map.get(s"${colName}.${KEY_MAX_LEN}").map(_.toLong),
+ histogram = map.get(s"${colName}.${KEY_HISTOGRAM}").map(HistogramSerializer.deserialize)
+ ))
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Failed to parse column statistics for column ${colName} in table $table", e)
+ None
+ }
+ }
+}
+
case class CatalogTableType private(name: String)
object CatalogTableType {
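
A quick round-trip sketch of the serialization protocol defined by CatalogColumnStat above. The column name "age" and table name "people" are hypothetical; the optional histogram field is simply omitted.

    import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat

    // Build a stat object; None fields will not appear in the serialized map.
    val stat = CatalogColumnStat(
      distinctCount = Some(BigInt(100)),
      min = Some("1"),
      max = Some("999"),
      nullCount = Some(BigInt(0)),
      avgLen = Some(4L),
      maxLen = Some(4L))

    // Keys are "<colName>.<field>", plus the mandatory "age.version" -> "1" entry.
    val serialized: Map[String, String] = stat.toMap("age")

    // fromMap is the inverse; it returns None (and logs a warning) if parsing fails.
    val restored = CatalogColumnStat.fromMap("people", "age", serialized)
    assert(restored.contains(stat))
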
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index efc2882f0a3d3..cbea3c017a265 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -128,7 +128,7 @@ object ExpressionEncoder {
case b: BoundReference if b == originalInputObject => newInputObject
})
- if (enc.flat) {
+ val serializerExpr = if (enc.flat) {
newSerializer.head
} else {
// For non-flat encoder, the input object is not top level anymore after being combined to
@@ -146,6 +146,7 @@ object ExpressionEncoder {
Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
If(nullCheck, Literal.create(null, struct.dataType), struct)
}
+ Alias(serializerExpr, s"_${index + 1}")()
}
val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
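
A small sketch of what the Alias added above produces, assuming only the pre-existing tuple-encoder API: the top-level fields of a combined encoder carry the positional names _1, _2, and so on.

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    val enc = ExpressionEncoder.tuple(ExpressionEncoder[Int](), ExpressionEncoder[String]())
    enc.schema.fieldNames // Array("_1", "_2")
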
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 789750fd408f2..3340789398f9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 6a17a397b3ef2..df3ab05e02c76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@@ -33,28 +34,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"
+ private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
+
// Use special getter for primitive types (for UnsafeRow)
override def eval(input: InternalRow): Any = {
- if (input.isNullAt(ordinal)) {
+ if (nullable && input.isNullAt(ordinal)) {
null
} else {
- dataType match {
- case BooleanType => input.getBoolean(ordinal)
- case ByteType => input.getByte(ordinal)
- case ShortType => input.getShort(ordinal)
- case IntegerType | DateType => input.getInt(ordinal)
- case LongType | TimestampType => input.getLong(ordinal)
- case FloatType => input.getFloat(ordinal)
- case DoubleType => input.getDouble(ordinal)
- case StringType => input.getUTF8String(ordinal)
- case BinaryType => input.getBinary(ordinal)
- case CalendarIntervalType => input.getInterval(ordinal)
- case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
- case t: StructType => input.getStruct(ordinal, t.size)
- case _: ArrayType => input.getArray(ordinal)
- case _: MapType => input.getMap(ordinal)
- case _ => input.get(ordinal, dataType)
- }
+ accessor(input, ordinal)
}
}
@@ -66,16 +53,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
ev.copy(code = oev.code)
} else {
assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
- val javaType = ctx.javaType(dataType)
- val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
+ val javaType = CodeGenerator.javaType(dataType)
+ val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
ev.copy(code =
- s"""
+ code"""
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
- |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
+ |$javaType ${ev.value} = ${ev.isNull} ?
+ | ${CodeGenerator.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
- ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
+ ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
}
}
}
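
A minimal sketch of the interpreted path after this change; the row contents are illustrative. The accessor is resolved once per BoundReference from its data type, and isNullAt is only consulted when the reference is nullable.

    import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow}
    import org.apache.spark.sql.types.{IntegerType, StringType}

    val row = new GenericInternalRow(Array[Any](7, null))
    BoundReference(0, IntegerType, nullable = false).eval(row) // 7, no null check performed
    BoundReference(1, StringType, nullable = true).eval(row)   // null
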
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
index d848ba18356d3..7541f527a52a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -30,6 +30,7 @@ package org.apache.spark.sql.catalyst.expressions
* by `hashCode`.
* - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
* - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
+ * - Elements in [[In]] are reordered by `hashCode`.
*/
object Canonicalize {
def execute(e: Expression): Expression = {
@@ -85,6 +86,11 @@ object Canonicalize {
case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r)
case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r)
+ // Order the list in the In operator.
+ // An In subquery contains only one element, of type ListQuery, so by requiring that the
+ // length is greater than 1 we avoid reordering In subqueries.
+ case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode()))
+
case _ => e
}
}
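
A small sketch of the new rule's effect, assuming catalyst test-style expression construction: permuted IN lists now compare as semantically equal, while an IN subquery (a single ListQuery child) is left untouched.

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, In, Literal}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val in1 = In(a, Seq(Literal(1), Literal(2), Literal(3)))
    val in2 = In(a, Seq(Literal(3), Literal(1), Literal(2)))
    in1.semanticEquals(in2) // true: both canonicalize to the same element ordering
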
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 79b051670e9e4..699ea53b5df0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
- ev.copy(code = eval.code +
- castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
+
+ ev.copy(code =
+ code"""
+ ${eval.code}
+ // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull}
+ ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)}
+ """)
}
// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`
@@ -669,7 +675,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = {
s"""
boolean $resultIsNull = $inputIsNull;
- ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
+ ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)};
if (!$inputIsNull) {
${cast(input, result, resultIsNull)}
}
@@ -685,7 +691,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val funcName = ctx.freshName("elementToString")
val elementToStringFunc = ctx.addNewFunction(funcName,
s"""
- |private UTF8String $funcName(${ctx.javaType(et)} element) {
+ |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) {
| UTF8String elementStr = null;
| ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
| return elementStr;
@@ -697,13 +703,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|$buffer.append("[");
|if ($array.numElements() > 0) {
| if (!$array.isNullAt(0)) {
- | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
+ | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")}));
| }
| for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
| $buffer.append(",");
| if (!$array.isNullAt($loopIndex)) {
| $buffer.append(" ");
- | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
+ | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)}));
| }
| }
|}
@@ -723,7 +729,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val dataToStringCode = castToStringCode(dataType, ctx)
ctx.addNewFunction(funcName,
s"""
- |private UTF8String $funcName(${ctx.javaType(dataType)} data) {
+ |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) {
| UTF8String dataStr = null;
| ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)}
| return dataStr;
@@ -734,23 +740,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val keyToStringFunc = dataToStringFunc("keyToString", kt)
val valueToStringFunc = dataToStringFunc("valueToString", vt)
val loopIndex = ctx.freshName("loopIndex")
+ val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0")
+ val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0")
+ val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex)
+ val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex)
s"""
|$buffer.append("[");
|if ($map.numElements() > 0) {
- | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")}));
+ | $buffer.append($keyToStringFunc($getMapFirstKey));
| $buffer.append(" ->");
| if (!$map.valueArray().isNullAt(0)) {
| $buffer.append(" ");
- | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")}));
+ | $buffer.append($valueToStringFunc($getMapFirstValue));
| }
| for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) {
| $buffer.append(", ");
- | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)}));
+ | $buffer.append($keyToStringFunc($getMapKeyArray));
| $buffer.append(" ->");
| if (!$map.valueArray().isNullAt($loopIndex)) {
| $buffer.append(" ");
- | $buffer.append($valueToStringFunc(
- | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)}));
+ | $buffer.append($valueToStringFunc($getMapValueArray));
| }
| }
|}
@@ -773,7 +782,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
| ${if (i != 0) s"""$buffer.append(" ");""" else ""}
|
| // Append $i field into the string buffer
- | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")};
+ | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")};
| UTF8String $fieldStr = null;
| ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
| $buffer.append($fieldStr);
@@ -1202,8 +1211,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
$values[$j] = null;
} else {
boolean $fromElementNull = false;
- ${ctx.javaType(fromType)} $fromElementPrim =
- ${ctx.getValue(c, fromType, j)};
+ ${CodeGenerator.javaType(fromType)} $fromElementPrim =
+ ${CodeGenerator.getValue(c, fromType, j)};
${castCode(ctx, fromElementPrim,
fromElementNull, toElementPrim, toElementNull, toType, elementCast)}
if ($toElementNull) {
@@ -1259,20 +1268,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val fromFieldNull = ctx.freshName("ffn")
val toFieldPrim = ctx.freshName("tfp")
val toFieldNull = ctx.freshName("tfn")
- val fromType = ctx.javaType(from.fields(i).dataType)
+ val fromType = CodeGenerator.javaType(from.fields(i).dataType)
s"""
boolean $fromFieldNull = $tmpInput.isNullAt($i);
if ($fromFieldNull) {
$tmpResult.setNullAt($i);
} else {
$fromType $fromFieldPrim =
- ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)};
+ ${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)};
${castCode(ctx, fromFieldPrim,
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
if ($toFieldNull) {
$tmpResult.setNullAt($i);
} else {
- ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
+ ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
}
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala
new file mode 100644
index 0000000000000..fb25e781e72e4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.codehaus.commons.compiler.CompileException
+import org.codehaus.janino.InternalCompilerException
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
+
+/**
+ * Catches compile errors during code generation.
+ */
+object CodegenError {
+ def unapply(throwable: Throwable): Option[Exception] = throwable match {
+ case e: InternalCompilerException => Some(e)
+ case e: CompileException => Some(e)
+ case _ => None
+ }
+}
+
+/**
+ * Defines the values for the `SQLConf` config that controls the fallback mode. For tests only.
+ */
+object CodegenObjectFactoryMode extends Enumeration {
+ val FALLBACK, CODEGEN_ONLY, NO_CODEGEN = Value
+}
+
+/**
+ * A codegen object generator which first tries to create objects through the codegen path. If a
+ * compile error occurs, it falls back to the interpreted implementation. In tests, the SQL config
+ * `SQLConf.CODEGEN_FACTORY_MODE` can be used to control the fallback behavior.
+ */
+abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] {
+
+ def createObject(in: IN): OUT = {
+ // Codegen-only and no-codegen modes can only be forced when running tests.
+ val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE)
+ val fallbackMode = CodegenObjectFactoryMode.withName(config)
+
+ fallbackMode match {
+ case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting =>
+ createCodeGeneratedObject(in)
+ case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting =>
+ createInterpretedObject(in)
+ case _ =>
+ try {
+ createCodeGeneratedObject(in)
+ } catch {
+ case CodegenError(_) => createInterpretedObject(in)
+ }
+ }
+ }
+
+ protected def createCodeGeneratedObject(in: IN): OUT
+ protected def createInterpretedObject(in: IN): OUT
+}
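
A minimal sketch of how a concrete factory plugs into the fallback mechanism above. NaiveEvaluatorFactory and its function-valued OUT type are hypothetical, chosen only to illustrate the two hooks; real factories return generated classes.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{CodeGeneratorWithInterpretedFallback, Expression}

    object NaiveEvaluatorFactory
        extends CodeGeneratorWithInterpretedFallback[Seq[Expression], InternalRow => Any] {

      // A real factory would invoke a GenerateXxx code generator here, which may throw
      // CompileException / InternalCompilerException.
      override protected def createCodeGeneratedObject(in: Seq[Expression]): InternalRow => Any =
        row => in.head.eval(row)

      // Pure interpreted path, used when compilation fails or when codegen is disabled in tests.
      override protected def createInterpretedObject(in: Seq[Expression]): InternalRow => Any =
        row => in.head.eval(row)
    }

Callers go through createObject(), which consults SQLConf.CODEGEN_FACTORY_MODE under tests and otherwise tries the codegen path first, falling back on compile errors.
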
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 4568714933095..9b9fa41a47d0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -22,6 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -104,11 +105,13 @@ abstract class Expression extends TreeNode[Expression] {
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
- val eval = doGenCode(ctx, ExprCode("", isNull, value))
+ val eval = doGenCode(ctx, ExprCode(
+ JavaCode.isNullVariable(isNull),
+ JavaCode.variable(value, dataType)))
reduceCodeSize(ctx, eval)
- if (eval.code.nonEmpty) {
+ if (eval.code.toString.nonEmpty) {
// Add `this` in the comment.
- eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
+ eval.copy(code = ctx.registerComment(this.toString) + eval.code)
} else {
eval
}
@@ -117,31 +120,31 @@ abstract class Expression extends TreeNode[Expression] {
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
- if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
- val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
- val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull")
+ if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
+ val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
+ val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
- eval.isNull = globalIsNull
+ eval.isNull = JavaCode.isNullGlobal(globalIsNull)
s"$globalIsNull = $localIsNull;"
} else {
""
}
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val newValue = ctx.freshName("value")
val funcName = ctx.freshName(nodeName)
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
- | ${eval.code.trim}
+ | ${eval.code}
| $setIsNull
| return ${eval.value};
|}
""".stripMargin)
- eval.value = newValue
- eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
+ eval.value = JavaCode.variable(newValue, dataType)
+ eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}
@@ -288,6 +291,7 @@ trait NonSQLExpression extends Expression {
final override def sql: String = {
transform {
case a: Attribute => new PrettyAttribute(a)
+ case a: Alias => PrettyAttribute(a.sql, a.dataType)
}.toString
}
}
@@ -328,6 +332,32 @@ trait Nondeterministic extends Expression {
protected def evalInternal(input: InternalRow): Any
}
+/**
+ * An expression that contains mutable state. A stateful expression is always non-deterministic
+ * because the results it produces during evaluation are not only dependent on the given input
+ * but also on its internal state.
+ *
+ * The state of a stateful expression is generally not exposed in its parameter list, which makes
+ * comparing stateful expressions problematic: similar stateful expressions (with the same
+ * parameter list) but with different internal state would be considered equal. This is especially
+ * problematic during tree transformations. To counter this, the `fastEquals` method for
+ * stateful expressions only returns `true` for the same reference.
+ *
+ * A stateful expression should never be evaluated multiple times for a single row. This should
+ * only be a problem for interpreted execution. It can be prevented by creating fresh copies
+ * of the stateful expression before execution; these can be made using the `freshCopy` function.
+ */
+trait Stateful extends Nondeterministic {
+ /**
+ * Return a fresh uninitialized copy of the stateful expression.
+ */
+ def freshCopy(): Stateful
+
+ /**
+ * Only the same reference is considered equal.
+ */
+ override def fastEquals(other: TreeNode[_]): Boolean = this eq other
+}
/**
* A leaf expression, i.e. one without any child expressions.
@@ -408,18 +438,17 @@ abstract class UnaryExpression extends Expression {
if (nullable) {
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
+ ev.copy(code = code"""
${childGen.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- $resultCode""", isNull = "false")
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ $resultCode""", isNull = FalseLiteral)
}
}
}
@@ -508,18 +537,17 @@ abstract class BinaryExpression extends Expression {
}
}
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
+ ev.copy(code = code"""
${leftGen.code}
${rightGen.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- $resultCode""", isNull = "false")
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ $resultCode""", isNull = FalseLiteral)
}
}
}
@@ -652,18 +680,17 @@ abstract class TernaryExpression extends Expression {
}
}
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
+ ev.copy(code = code"""
${leftGen.code}
${midGen.code}
${rightGen.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- $resultCode""", isNull = "false")
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ $resultCode""", isNull = FalseLiteral)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
new file mode 100644
index 0000000000000..55a5bd380859e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types.{UserDefinedType, _}
+import org.apache.spark.unsafe.Platform
+
+/**
+ * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces; a consumer
+ * should copy the row if it is being buffered. This class is not thread safe.
+ *
+ * @param expressions the expressions that produce the resulting fields. These expressions must
+ * be bound to a schema.
+ */
+class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection {
+ import InterpretedUnsafeProjection._
+
+ /** Number of (top level) fields in the resulting row. */
+ private[this] val numFields = expressions.length
+
+ /** Array holding the expression results. */
+ private[this] val values = new Array[Any](numFields)
+
+ /** The row representing the expression results. */
+ private[this] val intermediate = new GenericInternalRow(values)
+
+ /* The row writer for UnsafeRow result */
+ private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32)
+
+ /** The writer that writes the intermediate result to the result row. */
+ private[this] val writer: InternalRow => Unit = {
+ val baseWriter = generateStructWriter(
+ rowWriter,
+ expressions.map(e => StructField("", e.dataType, e.nullable)))
+ if (!expressions.exists(_.nullable)) {
+ // No nullable fields. The top-level null bit mask will always be zeroed out.
+ baseWriter
+ } else {
+ // Zero out the null bit mask before we write the row.
+ row => {
+ rowWriter.zeroOutNullBytes()
+ baseWriter(row)
+ }
+ }
+ }
+
+ override def initialize(partitionIndex: Int): Unit = {
+ expressions.foreach(_.foreach {
+ case n: Nondeterministic => n.initialize(partitionIndex)
+ case _ =>
+ })
+ }
+
+ override def apply(row: InternalRow): UnsafeRow = {
+ // Put the expression results in the intermediate row.
+ var i = 0
+ while (i < numFields) {
+ values(i) = expressions(i).eval(row)
+ i += 1
+ }
+
+ // Write the intermediate row to an unsafe row.
+ rowWriter.reset()
+ writer(intermediate)
+ rowWriter.getRow()
+ }
+}
+
+/**
+ * Helper functions for creating an [[InterpretedUnsafeProjection]].
+ */
+object InterpretedUnsafeProjection {
+ /**
+ * Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
+ */
+ def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
+ // We need to make sure that we do not reuse stateful expressions.
+ val cleanedExpressions = exprs.map(_.transform {
+ case s: Stateful => s.freshCopy()
+ })
+ new InterpretedUnsafeProjection(cleanedExpressions.toArray)
+ }
+
+ /**
+ * Generate a struct writer function. The generated function writes an [[InternalRow]] to the
+ * given buffer using the given [[UnsafeRowWriter]].
+ */
+ private def generateStructWriter(
+ rowWriter: UnsafeRowWriter,
+ fields: Array[StructField]): InternalRow => Unit = {
+ val numFields = fields.length
+
+ // Create field writers.
+ val fieldWriters = fields.map { field =>
+ generateFieldWriter(rowWriter, field.dataType, field.nullable)
+ }
+ // Create basic writer.
+ row => {
+ var i = 0
+ while (i < numFields) {
+ fieldWriters(i).apply(row, i)
+ i += 1
+ }
+ }
+ }
+
+ /**
+ * Generate a writer function for a struct field, array element, map key or map value. The
+ * generated function writes the element at an index in a [[SpecializedGetters]] object (row
+ * or array) to the given buffer using the given [[UnsafeWriter]].
+ */
+ private def generateFieldWriter(
+ writer: UnsafeWriter,
+ dt: DataType,
+ nullable: Boolean): (SpecializedGetters, Int) => Unit = {
+
+ // Create the basic writer.
+ val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match {
+ case BooleanType =>
+ (v, i) => writer.write(i, v.getBoolean(i))
+
+ case ByteType =>
+ (v, i) => writer.write(i, v.getByte(i))
+
+ case ShortType =>
+ (v, i) => writer.write(i, v.getShort(i))
+
+ case IntegerType | DateType =>
+ (v, i) => writer.write(i, v.getInt(i))
+
+ case LongType | TimestampType =>
+ (v, i) => writer.write(i, v.getLong(i))
+
+ case FloatType =>
+ (v, i) => writer.write(i, v.getFloat(i))
+
+ case DoubleType =>
+ (v, i) => writer.write(i, v.getDouble(i))
+
+ case DecimalType.Fixed(precision, scale) =>
+ (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale)
+
+ case CalendarIntervalType =>
+ (v, i) => writer.write(i, v.getInterval(i))
+
+ case BinaryType =>
+ (v, i) => writer.write(i, v.getBinary(i))
+
+ case StringType =>
+ (v, i) => writer.write(i, v.getUTF8String(i))
+
+ case StructType(fields) =>
+ val numFields = fields.length
+ val rowWriter = new UnsafeRowWriter(writer, numFields)
+ val structWriter = generateStructWriter(rowWriter, fields)
+ (v, i) => {
+ v.getStruct(i, fields.length) match {
+ case row: UnsafeRow =>
+ writer.write(i, row)
+ case row =>
+ val previousCursor = writer.cursor()
+ // Nested struct. We don't know where this will start because a row can be
+ // variable length, so we need to update the offsets and zero out the bit mask.
+ rowWriter.resetRowWriter()
+ structWriter.apply(row)
+ writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
+ }
+ }
+
+ case ArrayType(elementType, containsNull) =>
+ val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType))
+ val elementWriter = generateFieldWriter(
+ arrayWriter,
+ elementType,
+ containsNull)
+ (v, i) => {
+ val previousCursor = writer.cursor()
+ writeArray(arrayWriter, elementWriter, v.getArray(i))
+ writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
+ }
+
+ case MapType(keyType, valueType, valueContainsNull) =>
+ val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType))
+ val keyWriter = generateFieldWriter(
+ keyArrayWriter,
+ keyType,
+ nullable = false)
+ val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType))
+ val valueWriter = generateFieldWriter(
+ valueArrayWriter,
+ valueType,
+ valueContainsNull)
+ (v, i) => {
+ v.getMap(i) match {
+ case map: UnsafeMapData =>
+ writer.write(i, map)
+ case map =>
+ val previousCursor = writer.cursor()
+
+ // Reserve 8 bytes to write the key array's numBytes later.
+ valueArrayWriter.grow(8)
+ valueArrayWriter.increaseCursor(8)
+
+ // Write the keys and write the numBytes of key array into the first 8 bytes.
+ writeArray(keyArrayWriter, keyWriter, map.keyArray())
+ Platform.putLong(
+ valueArrayWriter.getBuffer,
+ previousCursor,
+ valueArrayWriter.cursor - previousCursor - 8
+ )
+
+ // Write the values.
+ writeArray(valueArrayWriter, valueWriter, map.valueArray())
+ writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor)
+ }
+ }
+
+ case udt: UserDefinedType[_] =>
+ generateFieldWriter(writer, udt.sqlType, nullable)
+
+ case NullType =>
+ (_, _) => {}
+
+ case _ =>
+ throw new SparkException(s"Unsupported data type $dt")
+ }
+
+ // Always wrap the writer with a null safe version.
+ dt match {
+ case _: UserDefinedType[_] =>
+ // The null wrapper depends on the sql type and not on the UDT.
+ unsafeWriter
+ case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS =>
+ // We can't call setNullAt() for DecimalType with precision larger than 18; we call write()
+ // directly instead, so the unwrapped writer can be used as-is.
+ unsafeWriter
+ case BooleanType | ByteType =>
+ (v, i) => {
+ if (!v.isNullAt(i)) {
+ unsafeWriter(v, i)
+ } else {
+ writer.setNull1Bytes(i)
+ }
+ }
+ case ShortType =>
+ (v, i) => {
+ if (!v.isNullAt(i)) {
+ unsafeWriter(v, i)
+ } else {
+ writer.setNull2Bytes(i)
+ }
+ }
+ case IntegerType | DateType | FloatType =>
+ (v, i) => {
+ if (!v.isNullAt(i)) {
+ unsafeWriter(v, i)
+ } else {
+ writer.setNull4Bytes(i)
+ }
+ }
+ case _ =>
+ (v, i) => {
+ if (!v.isNullAt(i)) {
+ unsafeWriter(v, i)
+ } else {
+ writer.setNull8Bytes(i)
+ }
+ }
+ }
+ }
+
+ /**
+ * Get the number of bytes an element of the given data type occupies in the fixed part of an
+ * [[UnsafeArrayData]] object. Reference types are stored as an 8-byte combination of an
+ * offset (upper 4 bytes) and a length (lower 4 bytes); these point to the variable-length
+ * portion of the array object. Primitives take up to 8 bytes, depending on the size of the
+ * underlying data type.
+ */
+ private def getElementSize(dataType: DataType): Int = dataType match {
+ case NullType | StringType | BinaryType | CalendarIntervalType |
+ _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8
+ case _ => dataType.defaultSize
+ }
+
+ /**
+ * Write an array to the buffer. If the array is already in serialized form (an instance of
+ * [[UnsafeArrayData]]) then we copy the bytes directly, otherwise we do an element-by-element
+ * copy.
+ */
+ private def writeArray(
+ arrayWriter: UnsafeArrayWriter,
+ elementWriter: (SpecializedGetters, Int) => Unit,
+ array: ArrayData): Unit = array match {
+ case unsafe: UnsafeArrayData =>
+ arrayWriter.write(unsafe)
+ case _ =>
+ val numElements = array.numElements()
+ arrayWriter.initialize(numElements)
+ var i = 0
+ while (i < numElements) {
+ elementWriter.apply(array, i)
+ i += 1
+ }
+ }
+}
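
A minimal usage sketch, assuming already-bound expressions: project a generic row into an UnsafeRow with no code generation involved. The returned row is reused across calls, so it must be copied if buffered.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, InterpretedUnsafeProjection, UnsafeRow}
    import org.apache.spark.sql.types.{IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    val exprs = Seq(
      BoundReference(0, IntegerType, nullable = false),
      BoundReference(1, StringType, nullable = true))

    // createProjection also swaps Stateful expressions for fresh copies.
    val proj = InterpretedUnsafeProjection.createProjection(exprs)
    val unsafe: UnsafeRow = proj(InternalRow(42, UTF8String.fromString("spark")))
    // call unsafe.copy() if the row is kept around
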
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 11fb579dfa88c..f1da592a76845 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType}
/**
@@ -38,8 +39,9 @@ import org.apache.spark.sql.types.{DataType, LongType}
puts the partition ID in the upper 31 bits, and the lower 33 bits represent the record number
within each partition. The assumption is that the data frame has less than 1 billion
partitions, and each partition has less than 8 billion records.
+ The function is non-deterministic because its result depends on partition IDs.
""")
-case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic {
+case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
/**
* Record ID within each partition. By being transient, count's value is reset to 0 every time
@@ -65,18 +67,20 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
+ val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
val partitionMaskTerm = "partitionMask"
- ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
+ ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm)
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
- ev.copy(code = s"""
- final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
- $countTerm++;""", isNull = "false")
+ ev.copy(code = code"""
+ final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
+ $countTerm++;""", isNull = FalseLiteral)
}
override def prettyName: String = "monotonically_increasing_id"
override def sql: String = s"$prettyName()"
+
+ override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID()
}
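
A small arithmetic sketch of the ID layout generated above, in plain Scala with no Spark API: the partition index occupies the upper 31 bits and the per-partition counter the lower 33 bits.

    val partitionIndex = 5
    val partitionMask = partitionIndex.toLong << 33
    val firstId  = partitionMask + 0L // 42949672960
    val secondId = partitionMask + 1L // 42949672961
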
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 64b94f0a2c103..6493f09100577 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -108,7 +108,31 @@ abstract class UnsafeProjection extends Projection {
override def apply(row: InternalRow): UnsafeRow
}
-object UnsafeProjection {
+/**
+ * The factory object for `UnsafeProjection`.
+ */
+object UnsafeProjection
+ extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] {
+
+ override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
+ GenerateUnsafeProjection.generate(in)
+ }
+
+ override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {
+ InterpretedUnsafeProjection.createProjection(in)
+ }
+
+ protected def toBoundExprs(
+ exprs: Seq[Expression],
+ inputSchema: Seq[Attribute]): Seq[Expression] = {
+ exprs.map(BindReferences.bindReference(_, inputSchema))
+ }
+
+ protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = {
+ exprs.map(_ transform {
+ case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+ })
+ }
/**
* Returns an UnsafeProjection for given StructType.
@@ -127,13 +151,10 @@ object UnsafeProjection {
}
/**
- * Returns an UnsafeProjection for given sequence of Expressions (bounded).
+ * Returns an UnsafeProjection for given sequence of bound Expressions.
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
- val unsafeExprs = exprs.map(_ transform {
- case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
- })
- GenerateUnsafeProjection.generate(unsafeExprs)
+ createObject(toUnsafeExprs(exprs))
}
def create(expr: Expression): UnsafeProjection = create(Seq(expr))
@@ -143,22 +164,24 @@ object UnsafeProjection {
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
- create(exprs.map(BindReferences.bindReference(_, inputSchema)))
+ create(toBoundExprs(exprs, inputSchema))
}
/**
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
- * TODO: refactor the plumbing and clean this up.
+ * The param `subexpressionEliminationEnabled` is not guaranteed to take effect. For example,
+ * it is not supported when falling back to interpreted execution.
*/
def create(
exprs: Seq[Expression],
inputSchema: Seq[Attribute],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
- val e = exprs.map(BindReferences.bindReference(_, inputSchema))
- .map(_ transform {
- case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
- })
- GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled)
+ val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema))
+ try {
+ GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled)
+ } catch {
+ case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs)
+ }
}
}
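
A minimal sketch of the factory path above, assuming catalyst test-style attribute construction: unbound expressions are bound against the input schema and compiled, with the interpreted projection used if compilation fails.

    import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal, UnsafeProjection}
    import org.apache.spark.sql.types.{LongType, StringType}

    val input = Seq(
      AttributeReference("id", LongType)(),
      AttributeReference("name", StringType)())

    // Projects (name, id + 1) out of rows shaped like `input`.
    val proj = UnsafeProjection.create(Seq(input(1), Add(input(0), Literal(1L))), input)
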
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index efd664dde725a..6530b176968f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -34,10 +34,14 @@ object PythonUDF {
e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType)
}
- def isGroupAggPandasUDF(e: Expression): Boolean = {
+ def isGroupedAggPandasUDF(e: Expression): Boolean = {
e.isInstanceOf[PythonUDF] &&
e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
}
+
+ // This is currently the same as GroupedAggPandasUDF, but we might support new types in the
+ // future, e.g. an N -> N transform.
+ def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 388ef42883ad3..3e7ca88249737 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.DataType
/**
@@ -49,6 +50,17 @@ case class ScalaUDF(
udfDeterministic: Boolean = true)
extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
+ // The constructor for SPARK 2.1 and 2.2
+ def this(
+ function: AnyRef,
+ dataType: DataType,
+ children: Seq[Expression],
+ inputTypes: Seq[DataType],
+ udfName: Option[String]) = {
+ this(
+ function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true)
+ }
+
override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
override def toString: String =
@@ -1007,24 +1019,25 @@ case class ScalaUDF(
val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}")
val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})"
val resultConverter = s"$convertersTerm[${children.length}]"
+ val boxedType = CodeGenerator.boxedType(dataType)
val callFunc =
s"""
- |${ctx.boxedType(dataType)} $resultTerm = null;
+ |$boxedType $resultTerm = null;
|try {
- | $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult);
+ | $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult);
|} catch (Exception e) {
| throw new org.apache.spark.SparkException($errorMsgTerm, e);
|}
""".stripMargin
ev.copy(code =
- s"""
+ code"""
|$evalCode
|${initArgs.mkString("\n")}
|$callFunc
|
|boolean ${ev.isNull} = $resultTerm == null;
- |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index ff7c98f714905..76a881146a146 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
abstract sealed class SortDirection {
@@ -147,7 +149,41 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
(!child.isAscending && child.nullOrdering == NullsLast)
}
- override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+ private lazy val calcPrefix: Any => Long = child.child.dataType match {
+ case BooleanType => (raw) =>
+ if (raw.asInstanceOf[Boolean]) 1 else 0
+ case DateType | TimestampType | _: IntegralType => (raw) =>
+ raw.asInstanceOf[java.lang.Number].longValue()
+ case FloatType | DoubleType => (raw) => {
+ val dVal = raw.asInstanceOf[java.lang.Number].doubleValue()
+ DoublePrefixComparator.computePrefix(dVal)
+ }
+ case StringType => (raw) =>
+ StringPrefixComparator.computePrefix(raw.asInstanceOf[UTF8String])
+ case BinaryType => (raw) =>
+ BinaryPrefixComparator.computePrefix(raw.asInstanceOf[Array[Byte]])
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ _.asInstanceOf[Decimal].toUnscaledLong
+ case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+ val p = Decimal.MAX_LONG_DIGITS
+ val s = p - (dt.precision - dt.scale)
+ (raw) => {
+ val value = raw.asInstanceOf[Decimal]
+ if (value.changePrecision(p, s)) value.toUnscaledLong else Long.MinValue
+ }
+ case dt: DecimalType => (raw) =>
+ DoublePrefixComparator.computePrefix(raw.asInstanceOf[Decimal].toDouble)
+ case _ => (Any) => 0L
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val value = child.child.eval(input)
+ if (value == null) {
+ null
+ } else {
+ calcPrefix(value)
+ }
+ }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childCode = child.child.genCode(ctx)
@@ -181,7 +217,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
}
ev.copy(code = childCode.code +
- s"""
+ code"""
|long ${ev.value} = 0L;
|boolean ${ev.isNull} = ${childCode.isNull};
|if (!${childCode.isNull}) {
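
A minimal sketch of the new interpreted path, assuming a bound child and the existing SortOrder factory: for integral types the prefix is the value itself, and a null child value yields a null prefix.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, SortOrder, SortPrefix}
    import org.apache.spark.sql.types.LongType

    val order = SortOrder(BoundReference(0, LongType, nullable = true), Ascending)
    SortPrefix(order).eval(InternalRow(42L)) // 42L
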
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index a160b9b275290..9856b37e53fbc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@@ -44,8 +45,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val idTerm = "partitionId"
- ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
+ ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
- ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
+ ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
+ isNull = FalseLiteral)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 9a9f579b37f58..84e38a8b2711e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -22,7 +22,8 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -164,8 +165,8 @@ case class PreciseTimestampConversion(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
ev.copy(code = eval.code +
- s"""boolean ${ev.isNull} = ${eval.isNull};
- |${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
+ code"""boolean ${ev.isNull} = ${eval.isNull};
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
""".stripMargin)
}
override def nullSafeEval(input: Any): Any = input
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index a45854a3b5146..f1bbbdabb41f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -206,27 +206,15 @@ object ApproximatePercentile {
* with limited memory. PercentileDigest is backed by [[QuantileSummaries]].
*
* @param summaries underlying probabilistic data structure [[QuantileSummaries]].
- * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the
- * underlying quantileSummaries is compressed.
*/
- class PercentileDigest(
- private var summaries: QuantileSummaries,
- private var isCompressed: Boolean) {
-
- // Trigger compression if the QuantileSummaries's buffer length exceeds
- // compressThresHoldBufferLength. The buffer length can be get by
- // quantileSummaries.sampled.length
- private[this] final val compressThresHoldBufferLength: Int = {
- // Max buffer length after compression.
- val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2
- // A safe upper bound for buffer length before compression
- maxBufferLengthAfterCompression * 2
- }
+ class PercentileDigest(private var summaries: QuantileSummaries) {
def this(relativeError: Double) = {
- this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true)
+ this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true))
}
+ private[sql] def isCompressed: Boolean = summaries.compressed
+
/** Returns compressed object of [[QuantileSummaries]] */
def quantileSummaries: QuantileSummaries = {
if (!isCompressed) compress()
@@ -236,14 +224,6 @@ object ApproximatePercentile {
/** Insert an observation value into the PercentileDigest data structure. */
def add(value: Double): Unit = {
summaries = summaries.insert(value)
- // The result of QuantileSummaries.insert is un-compressed
- isCompressed = false
-
- // Currently, QuantileSummaries ignores the construction parameter compressThresHold,
- // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here
- // to make sure QuantileSummaries doesn't occupy infinite memory.
- // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold
- if (summaries.sampled.length >= compressThresHoldBufferLength) compress()
}
/** In-place merges in another PercentileDigest. */
@@ -280,7 +260,6 @@ object ApproximatePercentile {
private final def compress(): Unit = {
summaries = summaries.compress()
- isCompressed = true
}
}
@@ -335,8 +314,8 @@ object ApproximatePercentile {
sampled(i) = Stats(value, g, delta)
i += 1
}
- val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count)
- new PercentileDigest(summary, isCompressed = true)
+ val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true)
+ new PercentileDigest(summary)
}
}
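
A minimal usage sketch of the simplified digest, assuming the pre-existing getPercentiles helper: the compression flag now lives inside QuantileSummaries, so the digest only needs the relative error.

    import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile

    val digest = new ApproximatePercentile.PercentileDigest(relativeError = 0.01)
    (1 to 1000).foreach(i => digest.add(i.toDouble))
    val median = digest.getPercentiles(Array(0.5)).head // approximately 500.0
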
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 708bdbfc36058..a133bc2361eb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -23,24 +23,12 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
-@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
-case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
-
- override def prettyName: String = "avg"
-
- override def children: Seq[Expression] = child :: Nil
+abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
override def nullable: Boolean = true
-
// Return data type.
override def dataType: DataType = resultType
- override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
-
- override def checkInputDataTypes(): TypeCheckResult =
- TypeUtils.checkForNumericExpr(child.dataType, "function average")
-
private lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
@@ -62,14 +50,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
/* count = */ Literal(0L)
)
- override lazy val updateExpressions = Seq(
- /* sum = */
- Add(
- sum,
- Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
- /* count = */ If(IsNull(child), count, count + 1L)
- )
-
override lazy val mergeExpressions = Seq(
/* sum = */ sum.left + sum.right,
/* count = */ count.left + count.right
@@ -85,4 +65,29 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
case _ =>
Cast(sum, resultType) / Cast(count, resultType)
}
+
+ protected def updateExpressionsDef: Seq[Expression] = Seq(
+ /* sum = */
+ Add(
+ sum,
+ Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
+ /* count = */ If(IsNull(child), count, count + 1L)
+ )
+
+ override lazy val updateExpressions = updateExpressionsDef
+}
+
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
+case class Average(child: Expression)
+ extends AverageLike(child) with ImplicitCastInputTypes {
+
+ override def prettyName: String = "avg"
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+ override def checkInputDataTypes(): TypeCheckResult =
+ TypeUtils.checkForNumericExpr(child.dataType, "function average")
}
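
Note: the same pattern — the concrete per-row update logic extracted into a protected updateExpressionsDef that subclasses can reuse or wrap — is applied to CentralMomentAgg, Corr, Count and Covariance below, and the new regr_* aggregates later in this patch are the consumers. A rough sketch of the extension point (the class name is hypothetical, not part of the patch):

    // Hypothetical subclass; RegrAvgX / RegrAvgY added later in this patch follow the
    // same shape, wrapping updateExpressionsDef with a null-pair filter.
    case class MyAvg(child: Expression)
        extends AverageLike(child) with ImplicitCastInputTypes {
      override def prettyName: String = "my_avg"
      override def children: Seq[Expression] = child :: Nil
      override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
      override def checkInputDataTypes(): TypeCheckResult =
        TypeUtils.checkForNumericExpr(child.dataType, "function my_avg")
      // reuse the default per-row update unchanged; a subclass may also transform it
      override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef
    }
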
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 572d29caf5bc9..6bbb083f1e18e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -67,35 +67,7 @@ abstract class CentralMomentAgg(child: Expression)
override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0))
- override val updateExpressions: Seq[Expression] = {
- val newN = n + Literal(1.0)
- val delta = child - avg
- val deltaN = delta / newN
- val newAvg = avg + deltaN
- val newM2 = m2 + delta * (delta - deltaN)
-
- val delta2 = delta * delta
- val deltaN2 = deltaN * deltaN
- val newM3 = if (momentOrder >= 3) {
- m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2)
- } else {
- Literal(0.0)
- }
- val newM4 = if (momentOrder >= 4) {
- m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 +
- delta * (delta * delta2 - deltaN * deltaN2)
- } else {
- Literal(0.0)
- }
-
- trimHigherOrder(Seq(
- If(IsNull(child), n, newN),
- If(IsNull(child), avg, newAvg),
- If(IsNull(child), m2, newM2),
- If(IsNull(child), m3, newM3),
- If(IsNull(child), m4, newM4)
- ))
- }
+ override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef
override val mergeExpressions: Seq[Expression] = {
@@ -128,6 +100,36 @@ abstract class CentralMomentAgg(child: Expression)
trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4))
}
+
+ protected def updateExpressionsDef: Seq[Expression] = {
+ val newN = n + Literal(1.0)
+ val delta = child - avg
+ val deltaN = delta / newN
+ val newAvg = avg + deltaN
+ val newM2 = m2 + delta * (delta - deltaN)
+
+ val delta2 = delta * delta
+ val deltaN2 = deltaN * deltaN
+ val newM3 = if (momentOrder >= 3) {
+ m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2)
+ } else {
+ Literal(0.0)
+ }
+ val newM4 = if (momentOrder >= 4) {
+ m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 +
+ delta * (delta * delta2 - deltaN * deltaN2)
+ } else {
+ Literal(0.0)
+ }
+
+ trimHigherOrder(Seq(
+ If(IsNull(child), n, newN),
+ If(IsNull(child), avg, newAvg),
+ If(IsNull(child), m2, newM2),
+ If(IsNull(child), m3, newM3),
+ If(IsNull(child), m4, newM4)
+ ))
+ }
}
// Compute the population standard deviation of a column
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 95a4a0d5af634..3cdef72c1f2c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -22,17 +22,13 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
/**
- * Compute Pearson correlation between two expressions.
+ * Base class for computing Pearson correlation between two expressions.
* When applied on empty data (i.e., count is zero), it returns NULL.
*
* Definition of Pearson correlation can be found at
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
*/
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.")
-// scalastyle:on line.size.limit
-case class Corr(x: Expression, y: Expression)
+abstract class PearsonCorrelation(x: Expression, y: Expression)
extends DeclarativeAggregate with ImplicitCastInputTypes {
override def children: Seq[Expression] = Seq(x, y)
@@ -51,7 +47,26 @@ case class Corr(x: Expression, y: Expression)
override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0))
- override val updateExpressions: Seq[Expression] = {
+ override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef
+
+ override val mergeExpressions: Seq[Expression] = {
+ val n1 = n.left
+ val n2 = n.right
+ val newN = n1 + n2
+ val dx = xAvg.right - xAvg.left
+ val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+ val dy = yAvg.right - yAvg.left
+ val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+ val newXAvg = xAvg.left + dxN * n2
+ val newYAvg = yAvg.left + dyN * n2
+ val newCk = ck.left + ck.right + dx * dyN * n1 * n2
+ val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2
+ val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2
+
+ Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk)
+ }
+
+ protected def updateExpressionsDef: Seq[Expression] = {
val newN = n + Literal(1.0)
val dx = x - xAvg
val dxN = dx / newN
@@ -73,24 +88,15 @@ case class Corr(x: Expression, y: Expression)
If(isNull, yMk, newYMk)
)
}
+}
- override val mergeExpressions: Seq[Expression] = {
-
- val n1 = n.left
- val n2 = n.right
- val newN = n1 + n2
- val dx = xAvg.right - xAvg.left
- val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
- val dy = yAvg.right - yAvg.left
- val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
- val newXAvg = xAvg.left + dxN * n2
- val newYAvg = yAvg.left + dyN * n2
- val newCk = ck.left + ck.right + dx * dyN * n1 * n2
- val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2
- val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2
- Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk)
- }
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.")
+// scalastyle:on line.size.limit
+case class Corr(x: Expression, y: Expression)
+ extends PearsonCorrelation(x, y) {
override val evaluateExpression: Expression = {
If(n === Literal(0.0), Literal.create(null, DoubleType),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 1990f2f2f0722..40582d0abd762 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -21,24 +21,16 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = """
- _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null.
-
- _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null.
-
- _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null.
- """)
-// scalastyle:on line.size.limit
-case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
-
+/**
+ * Base class for all counting aggregators.
+ */
+abstract class CountLike extends DeclarativeAggregate {
override def nullable: Boolean = false
// Return data type.
override def dataType: DataType = LongType
- private lazy val count = AttributeReference("count", LongType, nullable = false)()
+ protected lazy val count = AttributeReference("count", LongType, nullable = false)()
override lazy val aggBufferAttributes = count :: Nil
@@ -46,6 +38,27 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
/* count = */ Literal(0L)
)
+ override lazy val mergeExpressions = Seq(
+ /* count = */ count.left + count.right
+ )
+
+ override lazy val evaluateExpression = count
+
+ override def defaultResult: Option[Literal] = Option(Literal(0L))
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null.
+
+ _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null.
+
+ _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null.
+ """)
+// scalastyle:on line.size.limit
+case class Count(children: Seq[Expression]) extends CountLike {
+
override lazy val updateExpressions = {
val nullableChildren = children.filter(_.nullable)
if (nullableChildren.isEmpty) {
@@ -58,14 +71,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
)
}
}
-
- override lazy val mergeExpressions = Seq(
- /* count = */ count.left + count.right
- )
-
- override lazy val evaluateExpression = count
-
- override def defaultResult: Option[Literal] = Option(Literal(0L))
}
object Count {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index fc6c34baafdd1..72a7c62b328ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -42,23 +42,7 @@ abstract class Covariance(x: Expression, y: Expression)
override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0))
- override lazy val updateExpressions: Seq[Expression] = {
- val newN = n + Literal(1.0)
- val dx = x - xAvg
- val dy = y - yAvg
- val dyN = dy / newN
- val newXAvg = xAvg + dx / newN
- val newYAvg = yAvg + dyN
- val newCk = ck + dx * (y - newYAvg)
-
- val isNull = IsNull(x) || IsNull(y)
- Seq(
- If(isNull, n, newN),
- If(isNull, xAvg, newXAvg),
- If(isNull, yAvg, newYAvg),
- If(isNull, ck, newCk)
- )
- }
+ override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef
override val mergeExpressions: Seq[Expression] = {
@@ -75,6 +59,24 @@ abstract class Covariance(x: Expression, y: Expression)
Seq(newN, newXAvg, newYAvg, newCk)
}
+
+ protected def updateExpressionsDef: Seq[Expression] = {
+ val newN = n + Literal(1.0)
+ val dx = x - xAvg
+ val dy = y - yAvg
+ val dyN = dy / newN
+ val newXAvg = xAvg + dx / newN
+ val newYAvg = yAvg + dyN
+ val newCk = ck + dx * (y - newYAvg)
+
+ val isNull = IsNull(x) || IsNull(y)
+ Seq(
+ If(isNull, n, newN),
+ If(isNull, xAvg, newXAvg),
+ If(isNull, yAvg, newYAvg),
+ If(isNull, ck, newCk)
+ )
+ }
}
@ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala
new file mode 100644
index 0000000000000..d8f4505588ff2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{AbstractDataType, DoubleType}
+
+/**
+ * Base trait for all regression functions.
+ */
+trait RegrLike extends AggregateFunction with ImplicitCastInputTypes {
+ def y: Expression
+ def x: Expression
+
+ override def children: Seq[Expression] = Seq(y, x)
+ override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+
+ protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = {
+ assert(aggBufferAttributes.length == exprs.length)
+ val nullableChildren = children.filter(_.nullable)
+ if (nullableChildren.isEmpty) {
+ exprs
+ } else {
+ exprs.zip(aggBufferAttributes).map { case (e, a) =>
+ If(nullableChildren.map(IsNull).reduce(Or), a, e)
+ }
+ }
+ }
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the number of non-null pairs.",
+ since = "2.4.0")
+case class RegrCount(y: Expression, x: Expression)
+ extends CountLike with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L))
+
+ override def prettyName: String = "regr_count"
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+case class RegrSXX(y: Expression, x: Expression)
+ extends CentralMomentAgg(x) with RegrLike {
+
+ override protected def momentOrder = 2
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType), m2)
+ }
+
+ override def prettyName: String = "regr_sxx"
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+case class RegrSYY(y: Expression, x: Expression)
+ extends CentralMomentAgg(y) with RegrLike {
+
+ override protected def momentOrder = 2
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType), m2)
+ }
+
+ override def prettyName: String = "regr_syy"
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+case class RegrAvgX(y: Expression, x: Expression)
+ extends AverageLike(x) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override def prettyName: String = "regr_avgx"
+}
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+case class RegrAvgY(y: Expression, x: Expression)
+ extends AverageLike(y) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override def prettyName: String = "regr_avgy"
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrSXY(y: Expression, x: Expression)
+ extends Covariance(y, x) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0), Literal.create(null, DoubleType), ck)
+ }
+
+ override def prettyName: String = "regr_sxy"
+}
+
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrSlope(y: Expression, x: Expression)
+ extends PearsonCorrelation(y, x) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk)
+ }
+
+ override def prettyName: String = "regr_slope"
+}
+
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrR2(y: Expression, x: Expression)
+ extends PearsonCorrelation(y, x) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType),
+ If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk))
+ }
+
+ override def prettyName: String = "regr_r2"
+}
+
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. Any pair with a NULL is ignored.",
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrIntercept(y: Expression, x: Expression)
+ extends PearsonCorrelation(y, x) with RegrLike {
+
+ override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+ override val evaluateExpression: Expression = {
+ If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType),
+ xAvg - (ck / yMk) * yAvg)
+ }
+
+ override def prettyName: String = "regr_intercept"
+}
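
Note: the nine new expressions above reuse the *Like base classes extracted earlier in this patch. Assuming they are registered in FunctionRegistry elsewhere in the change, a rough usage sketch (spark is an existing SparkSession; the data is illustrative):

    val df = spark.range(0, 10).selectExpr("CAST(id AS DOUBLE) AS x", "id * 2.0 + 1.0 AS y")
    df.selectExpr(
      "regr_count(y, x)", "regr_avgx(y, x)", "regr_avgy(y, x)",
      "regr_slope(y, x)", "regr_intercept(y, x)", "regr_r2(y, x)").show()
    // for this perfectly linear data the expected results are, approximately:
    // count = 10, avgx = 4.5, avgy = 10.0, slope = 2.0, intercept = 1.0, r2 = 1.0
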
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 8bb14598a6d7b..fe91e520169b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -43,16 +44,16 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
private lazy val numeric = TypeUtils.getNumeric(dataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
- case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
+ case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
// codegen would fail to compile if we just write (-($c))
// for example, we could not write --9223372036854775808L in code
s"""
- ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval);
- ${ev.value} = (${ctx.javaType(dt)})(-($originValue));
+ ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval);
+ ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue));
"""})
- case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
+ case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
}
protected override def nullSafeEval(input: Any): Any = {
@@ -104,10 +105,10 @@ case class Abs(child: Expression)
private lazy val numeric = TypeUtils.getNumeric(dataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
- case dt: DecimalType =>
+ case _: DecimalType =>
defineCodeGen(ctx, ev, c => s"$c.abs()")
case dt: NumericType =>
- defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))")
+ defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))")
}
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
@@ -117,19 +118,25 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
override def dataType: DataType = left.dataType
- override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess
+ override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
+ /** Name of the function for this expression on a [[CalendarInterval]] type. */
+ def calendarIntervalMethod: String =
+ sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
- case dt: DecimalType =>
+ case _: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
+ case CalendarIntervalType =>
+ defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)")
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
- (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
+ (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
@@ -152,6 +159,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"
+ override def decimalMethod: String = "$plus"
+
+ override def calendarIntervalMethod: String = "add"
+
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -161,18 +172,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
numeric.plus(input1, input2)
}
}
-
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
- case dt: DecimalType =>
- defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
- case ByteType | ShortType =>
- defineCodeGen(ctx, ev,
- (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
- case CalendarIntervalType =>
- defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
- case _ =>
- defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
- }
}
@ExpressionDescription(
@@ -188,6 +187,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
override def symbol: String = "-"
+ override def decimalMethod: String = "$minus"
+
+ override def calendarIntervalMethod: String = "subtract"
+
private lazy val numeric = TypeUtils.getNumeric(dataType)
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -197,18 +200,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
numeric.minus(input1, input2)
}
}
-
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
- case dt: DecimalType =>
- defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
- case ByteType | ShortType =>
- defineCodeGen(ctx, ev,
- (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
- case CalendarIntervalType =>
- defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
- case _ =>
- defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
- }
}
@ExpressionDescription(
@@ -230,30 +221,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.",
- examples = """
- Examples:
- > SELECT 3 _FUNC_ 2;
- 1.5
- > SELECT 2L _FUNC_ 2L;
- 1.0
- """)
-// scalastyle:on line.size.limit
-case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
-
- override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
+// Common base trait for Divide and Remainder, since these two classes are almost identical
+trait DivModLike extends BinaryArithmetic {
- override def symbol: String = "/"
- override def decimalMethod: String = "$div"
override def nullable: Boolean = true
- private lazy val div: (Any, Any) => Any = dataType match {
- case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
- }
-
- override def eval(input: InternalRow): Any = {
+ final override def eval(input: InternalRow): Any = {
val input2 = right.eval(input)
if (input2 == null || input2 == 0) {
null
@@ -262,13 +235,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
if (input1 == null) {
null
} else {
- div(input1, input2)
+ evalOperation(input1, input2)
}
}
}
+ def evalOperation(left: Any, right: Any): Any
+
/**
- * Special case handling due to division by 0 => null.
+ * Special case handling due to division/remainder by 0 => null.
*/
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval1 = left.genCode(ctx)
@@ -278,28 +253,28 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
} else {
s"${eval2.value} == 0"
}
- val javaType = ctx.javaType(dataType)
- val divide = if (dataType.isInstanceOf[DecimalType]) {
+ val javaType = CodeGenerator.javaType(dataType)
+ val operation = if (dataType.isInstanceOf[DecimalType]) {
s"${eval1.value}.$decimalMethod(${eval2.value})"
} else {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if ($isZero) {
${ev.isNull} = true;
} else {
${eval1.code}
- ${ev.value} = $divide;
+ ${ev.value} = $operation;
}""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
@@ -307,13 +282,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
if (${eval1.isNull}) {
${ev.isNull} = true;
} else {
- ${ev.value} = $divide;
+ ${ev.value} = $operation;
}
}""")
}
}
}
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.",
+ examples = """
+ Examples:
+ > SELECT 3 _FUNC_ 2;
+ 1.5
+ > SELECT 2L _FUNC_ 2L;
+ 1.0
+ """)
+// scalastyle:on line.size.limit
+case class Divide(left: Expression, right: Expression) extends DivModLike {
+
+ override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
+
+ override def symbol: String = "/"
+ override def decimalMethod: String = "$div"
+
+ private lazy val div: (Any, Any) => Any = dataType match {
+ case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
+ }
+
+ override def evalOperation(left: Any, right: Any): Any = div(left, right)
+}
+
@ExpressionDescription(
usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.",
examples = """
@@ -323,82 +323,30 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
> SELECT MOD(2, 1.8);
0.2
""")
-case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Remainder(left: Expression, right: Expression) extends DivModLike {
override def inputType: AbstractDataType = NumericType
override def symbol: String = "%"
override def decimalMethod: String = "remainder"
- override def nullable: Boolean = true
- private lazy val integral = dataType match {
- case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
- case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
- }
-
- override def eval(input: InternalRow): Any = {
- val input2 = right.eval(input)
- if (input2 == null || input2 == 0) {
- null
- } else {
- val input1 = left.eval(input)
- if (input1 == null) {
- null
- } else {
- input1 match {
- case d: Double => d % input2.asInstanceOf[java.lang.Double]
- case f: Float => f % input2.asInstanceOf[java.lang.Float]
- case _ => integral.rem(input1, input2)
- }
- }
- }
+ private lazy val mod: (Any, Any) => Any = dataType match {
+ // special cases to make float/double primitive types faster
+ case DoubleType =>
+ (left, right) => left.asInstanceOf[Double] % right.asInstanceOf[Double]
+ case FloatType =>
+ (left, right) => left.asInstanceOf[Float] % right.asInstanceOf[Float]
+
+ // catch-all cases
+ case i: IntegralType =>
+ val integral = i.integral.asInstanceOf[Integral[Any]]
+ (left, right) => integral.rem(left, right)
+ case i: FractionalType => // should only be DecimalType for now
+ val integral = i.asIntegral.asInstanceOf[Integral[Any]]
+ (left, right) => integral.rem(left, right)
}
- /**
- * Special case handling for x % 0 ==> null.
- */
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val eval1 = left.genCode(ctx)
- val eval2 = right.genCode(ctx)
- val isZero = if (dataType.isInstanceOf[DecimalType]) {
- s"${eval2.value}.isZero()"
- } else {
- s"${eval2.value} == 0"
- }
- val javaType = ctx.javaType(dataType)
- val remainder = if (dataType.isInstanceOf[DecimalType]) {
- s"${eval1.value}.$decimalMethod(${eval2.value})"
- } else {
- s"($javaType)(${eval1.value} $symbol ${eval2.value})"
- }
- if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
- ${eval2.code}
- boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
- if ($isZero) {
- ${ev.isNull} = true;
- } else {
- ${eval1.code}
- ${ev.value} = $remainder;
- }""")
- } else {
- ev.copy(code = s"""
- ${eval2.code}
- boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
- if (${eval2.isNull} || $isZero) {
- ${ev.isNull} = true;
- } else {
- ${eval1.code}
- if (${eval1.isNull}) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = $remainder;
- }
- }""")
- }
- }
+ override def evalOperation(left: Any, right: Any): Any = mod(left, right)
}
@ExpressionDescription(
@@ -416,7 +364,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "pmod"
- protected def checkTypesInternal(t: DataType) =
+ protected def checkTypesInternal(t: DataType): TypeCheckResult =
TypeUtils.checkForNumericExpr(t, "pmod")
override def inputType: AbstractDataType = NumericType
@@ -454,13 +402,13 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
s"${eval2.value} == 0"
}
val remainder = ctx.freshName("remainder")
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val result = dataType match {
case DecimalType.Fixed(_, _) =>
val decimalAdd = "$plus"
s"""
- ${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value});
+ $javaType $remainder = ${eval1.value}.remainder(${eval2.value});
if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value});
} else {
@@ -470,17 +418,16 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
s"""
- ${ctx.javaType(dataType)} $remainder =
- (${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value});
+ $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value});
if ($remainder < 0) {
- ${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value});
+ ${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value});
} else {
${ev.value}=$remainder;
}
"""
case _ =>
s"""
- ${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value};
+ $javaType $remainder = ${eval1.value} % ${eval2.value};
if ($remainder < 0) {
${ev.value}=($remainder + ${eval2.value}) % ${eval2.value};
} else {
@@ -490,10 +437,10 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
}
if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if ($isZero) {
${ev.isNull} = true;
} else {
@@ -501,10 +448,10 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
$result
}""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
@@ -602,19 +549,15 @@ case class Least(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
- ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+ ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
- |if (!${eval.isNull} && (${ev.isNull} ||
- | ${ctx.genGreater(dataType, ev.value, eval.value)})) {
- | ${ev.isNull} = false;
- | ${ev.value} = ${eval.value};
- |}
+ |${ctx.reassignIfSmaller(dataType, ev, eval)}
""".stripMargin
)
- val resultType = ctx.javaType(dataType)
+ val resultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "least",
@@ -627,9 +570,9 @@ case class Least(children: Seq[Expression]) extends Expression {
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|${ev.isNull} = true;
- |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
""".stripMargin)
}
@@ -681,19 +624,15 @@ case class Greatest(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
- ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+ ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
- |if (!${eval.isNull} && (${ev.isNull} ||
- | ${ctx.genGreater(dataType, eval.value, ev.value)})) {
- | ${ev.isNull} = false;
- | ${ev.value} = ${eval.value};
- |}
+ |${ctx.reassignIfGreater(dataType, ev, eval)}
""".stripMargin
)
- val resultType = ctx.javaType(dataType)
+ val resultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "greatest",
@@ -706,9 +645,9 @@ case class Greatest(children: Seq[Expression]) extends Expression {
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|${ev.isNull} = true;
- |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
""".stripMargin)
}
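
Note: the DivModLike refactoring keeps the documented special case — division or remainder by zero evaluates to NULL rather than throwing — on both the interpreted eval() path and the generated code. A quick check (sketch; assumes an existing SparkSession named spark):

    spark.sql("SELECT 3 / 0 AS div_by_zero, 3 % 0 AS mod_by_zero").show()
    // both columns are NULL: eval() short-circuits when the divisor is 0, and the
    // generated code guards the operation with the same isZero check
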
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 173481f06a716..cc24e397cc14a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -147,7 +147,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)")
+ defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)")
}
protected override def nullSafeEval(input: Any): Any = not(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4dcbb702893da..66315e5906253 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -38,10 +38,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types._
import org.apache.spark.util.{ParentClassLoader, Utils}
@@ -56,7 +58,21 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
-case class ExprCode(var code: String, var isNull: String, var value: String)
+case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)
+
+object ExprCode {
+ def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
+ ExprCode(code = EmptyBlock, isNull, value)
+ }
+
+ def forNullValue(dataType: DataType): ExprCode = {
+ ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
+ }
+
+ def forNonNullValue(value: ExprValue): ExprCode = {
+ ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value)
+ }
+}
/**
* State used for subexpression elimination.
@@ -66,7 +82,7 @@ case class ExprCode(var code: String, var isNull: String, var value: String)
* @param value A term for a value of a common sub-expression. Not valid if `isNull`
* is set to `true`.
*/
-case class SubExprEliminationState(isNull: String, value: String)
+case class SubExprEliminationState(isNull: ExprValue, value: ExprValue)
/**
* Codes and common subexpressions mapping used for subexpression elimination.
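
Note: from here on, ExprCode and SubExprEliminationState carry typed Block / ExprValue instances instead of raw strings, which is why the arithmetic changes above switch from s"..." to the code"..." interpolator. A minimal sketch using only the helpers visible in this diff (the variable names are made up):

    import org.apache.spark.sql.catalyst.expressions.codegen.Block._
    val value = JavaCode.global("myValue", IntegerType)   // typed reference to a global int slot
    val isNull = JavaCode.isNullGlobal("myIsNull")        // typed reference to a global boolean flag
    val ev = ExprCode(code"$isNull = false; $value = 42;", isNull, value)
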
@@ -99,6 +115,8 @@ private[codegen] case class NewFunctionSpec(
*/
class CodegenContext {
+ import CodeGenerator._
+
/**
* Holding a list of objects that could be used passed into generated class.
*/
@@ -190,11 +208,11 @@ class CodegenContext {
/**
* Returns the reference of next available slot in current compacted array. The size of each
- * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
+ * compacted array is controlled by the constant `MUTABLESTATEARRAY_SIZE_LIMIT`.
* Once reaching the threshold, new compacted array is created.
*/
def getNextSlot(): String = {
- if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) {
+ if (currentIndex < MUTABLESTATEARRAY_SIZE_LIMIT) {
val res = s"${arrayNames.last}[$currentIndex]"
currentIndex += 1
res
@@ -241,10 +259,10 @@ class CodegenContext {
* are satisfied:
* 1. forceInline is true
* 2. its type is primitive type and the total number of the inlined mutable variables
- * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD`
+ * is less than `OUTER_CLASS_VARIABLES_THRESHOLD`
* 3. its type is multi-dimensional array
* When a variable is compacted into an array, the max size of the array for compaction
- * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
+ * is given by `MUTABLESTATEARRAY_SIZE_LIMIT`.
*/
def addMutableState(
javaType: String,
@@ -255,7 +273,7 @@ class CodegenContext {
// want to put a primitive type variable at outerClass for performance
val canInlinePrimitive = isPrimitiveType(javaType) &&
- (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
+ (inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD)
if (forceInline || canInlinePrimitive || javaType.contains("[][]")) {
val varName = if (useFreshName) freshName(variableName) else variableName
val initCode = initFunc(varName)
@@ -313,11 +331,11 @@ class CodegenContext {
def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
val value = addMutableState(javaType(dataType), variableName)
val code = dataType match {
- case StringType => s"$value = $initCode.clone();"
- case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
- case _ => s"$value = $initCode;"
+ case StringType => code"$value = $initCode.clone();"
+ case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();"
+ case _ => code"$value = $initCode;"
}
- ExprCode(code, "false", value)
+ ExprCode(code, FalseLiteral, JavaCode.global(value, dataType))
}
def declareMutableStates(): String = {
@@ -333,7 +351,7 @@ class CodegenContext {
val length = if (index + 1 == numArrays) {
mutableStateArrays.getCurrentIndex
} else {
- CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT
+ MUTABLESTATEARRAY_SIZE_LIMIT
}
if (javaType.contains("[]")) {
// initializer had an one-dimensional array variable
@@ -389,7 +407,7 @@ class CodegenContext {
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
// Foreach expression that is participating in subexpression elimination, the state to use.
- val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
+ var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState]
// The collection of sub-expression result resetting methods that need to be called on each row.
val subexprFunctions = mutable.ArrayBuffer.empty[String]
@@ -462,7 +480,7 @@ class CodegenContext {
inlineToOuterClass: Boolean): NewFunctionSpec = {
val (className, classInstance) = if (inlineToOuterClass) {
outerClassName -> ""
- } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) {
+ } else if (currClassSize > GENERATED_CLASS_SIZE_THRESHOLD) {
val className = freshName("NestedClass")
val classInstance = freshName("nestedClassInstance")
@@ -531,14 +549,6 @@ class CodegenContext {
extraClasses.append(code)
}
- final val JAVA_BOOLEAN = "boolean"
- final val JAVA_BYTE = "byte"
- final val JAVA_SHORT = "short"
- final val JAVA_INT = "int"
- final val JAVA_LONG = "long"
- final val JAVA_FLOAT = "float"
- final val JAVA_DOUBLE = "double"
-
/**
* The map from a variable name to it's next ID.
*/
@@ -564,213 +574,20 @@ class CodegenContext {
} else {
s"${freshNamePrefix}_$name"
}
- if (freshNameIds.contains(fullName)) {
- val id = freshNameIds(fullName)
- freshNameIds(fullName) = id + 1
- s"$fullName$id"
- } else {
- freshNameIds += fullName -> 1
- fullName
- }
- }
-
- /**
- * Returns the specialized code to access a value from `inputRow` at `ordinal`.
- */
- def getValue(input: String, dataType: DataType, ordinal: String): String = {
- val jt = javaType(dataType)
- dataType match {
- case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)"
- case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})"
- case StringType => s"$input.getUTF8String($ordinal)"
- case BinaryType => s"$input.getBinary($ordinal)"
- case CalendarIntervalType => s"$input.getInterval($ordinal)"
- case t: StructType => s"$input.getStruct($ordinal, ${t.size})"
- case _: ArrayType => s"$input.getArray($ordinal)"
- case _: MapType => s"$input.getMap($ordinal)"
- case NullType => "null"
- case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
- case _ => s"($jt)$input.get($ordinal, null)"
- }
- }
-
- /**
- * Returns the code to update a column in Row for a given DataType.
- */
- def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
- val jt = javaType(dataType)
- dataType match {
- case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
- case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
- case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
- // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy
- // it to avoid keeping a "pointer" to a memory region which may get updated afterwards.
- case StringType | _: StructType | _: ArrayType | _: MapType =>
- s"$row.update($ordinal, $value.copy())"
- case _ => s"$row.update($ordinal, $value)"
- }
- }
-
- /**
- * Update a column in MutableRow from ExprCode.
- *
- * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
- */
- def updateColumn(
- row: String,
- dataType: DataType,
- ordinal: Int,
- ev: ExprCode,
- nullable: Boolean,
- isVectorized: Boolean = false): String = {
- if (nullable) {
- // Can't call setNullAt on DecimalType, because we need to keep the offset
- if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
- s"""
- if (!${ev.isNull}) {
- ${setColumn(row, dataType, ordinal, ev.value)};
- } else {
- ${setColumn(row, dataType, ordinal, "null")};
- }
- """
- } else {
- s"""
- if (!${ev.isNull}) {
- ${setColumn(row, dataType, ordinal, ev.value)};
- } else {
- $row.setNullAt($ordinal);
- }
- """
- }
- } else {
- s"""${setColumn(row, dataType, ordinal, ev.value)};"""
- }
- }
-
- /**
- * Returns the specialized code to set a given value in a column vector for a given `DataType`.
- */
- def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
- val jt = javaType(dataType)
- dataType match {
- case _ if isPrimitiveType(jt) =>
- s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
- case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
- case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
- case _ =>
- throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
- }
- }
-
- /**
- * Returns the specialized code to set a given value in a column vector for a given `DataType`
- * that could potentially be nullable.
- */
- def updateColumn(
- vector: String,
- rowId: String,
- dataType: DataType,
- ev: ExprCode,
- nullable: Boolean): String = {
- if (nullable) {
- s"""
- if (!${ev.isNull}) {
- ${setValue(vector, rowId, dataType, ev.value)}
- } else {
- $vector.putNull($rowId);
- }
- """
- } else {
- s"""${setValue(vector, rowId, dataType, ev.value)};"""
- }
- }
-
- /**
- * Returns the specialized code to access a value from a column vector for a given `DataType`.
- */
- def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = {
- if (dataType.isInstanceOf[StructType]) {
- // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an
- // `ordinal` parameter.
- s"$vector.getStruct($rowId)"
- } else {
- getValue(vector, dataType, rowId)
- }
- }
-
- /**
- * Returns the name used in accessor and setter for a Java primitive type.
- */
- def primitiveTypeName(jt: String): String = jt match {
- case JAVA_INT => "Int"
- case _ => boxedType(jt)
- }
-
- def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt))
-
- /**
- * Returns the Java type for a DataType.
- */
- def javaType(dt: DataType): String = dt match {
- case BooleanType => JAVA_BOOLEAN
- case ByteType => JAVA_BYTE
- case ShortType => JAVA_SHORT
- case IntegerType | DateType => JAVA_INT
- case LongType | TimestampType => JAVA_LONG
- case FloatType => JAVA_FLOAT
- case DoubleType => JAVA_DOUBLE
- case dt: DecimalType => "Decimal"
- case BinaryType => "byte[]"
- case StringType => "UTF8String"
- case CalendarIntervalType => "CalendarInterval"
- case _: StructType => "InternalRow"
- case _: ArrayType => "ArrayData"
- case _: MapType => "MapData"
- case udt: UserDefinedType[_] => javaType(udt.sqlType)
- case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
- case ObjectType(cls) => cls.getName
- case _ => "Object"
+ val id = freshNameIds.getOrElse(fullName, 0)
+ freshNameIds(fullName) = id + 1
+ s"${fullName}_$id"
}
- /**
- * Returns the boxed type in Java.
- */
- def boxedType(jt: String): String = jt match {
- case JAVA_BOOLEAN => "Boolean"
- case JAVA_BYTE => "Byte"
- case JAVA_SHORT => "Short"
- case JAVA_INT => "Integer"
- case JAVA_LONG => "Long"
- case JAVA_FLOAT => "Float"
- case JAVA_DOUBLE => "Double"
- case other => other
- }
-
- def boxedType(dt: DataType): String = boxedType(javaType(dt))
-
- /**
- * Returns the representation of default value for a given Java Type.
- */
- def defaultValue(jt: String): String = jt match {
- case JAVA_BOOLEAN => "false"
- case JAVA_BYTE => "(byte)-1"
- case JAVA_SHORT => "(short)-1"
- case JAVA_INT => "-1"
- case JAVA_LONG => "-1L"
- case JAVA_FLOAT => "-1.0f"
- case JAVA_DOUBLE => "-1.0"
- case _ => "null"
- }
-
- def defaultValue(dt: DataType): String = defaultValue(javaType(dt))
-
/**
* Generates code for equal expression in Java.
*/
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
- case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
- case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
+ case FloatType =>
+ s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)"
+ case DoubleType =>
+ s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
@@ -806,6 +623,7 @@ class CodegenContext {
val isNullB = freshName("isNullB")
val compareFunc = freshName("compareArray")
val minLength = freshName("minLength")
+ val jt = javaType(elementType)
val funcCode: String =
s"""
public int $compareFunc(ArrayData a, ArrayData b) {
@@ -827,8 +645,8 @@ class CodegenContext {
} else if ($isNullB) {
return 1;
} else {
- ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")};
- ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")};
+ $jt $elementA = ${getValue("a", elementType, "i")};
+ $jt $elementB = ${getValue("b", elementType, "i")};
int comp = ${genComp(elementType, elementA, elementB)};
if (comp != 0) {
return comp;
@@ -880,6 +698,107 @@ class CodegenContext {
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
}
+ /**
+ * Generates code for updating `partialResult` if `item` is smaller than it.
+ *
+ * @param dataType data type of the expressions
+ * @param partialResult `ExprCode` representing the partial result which has to be updated
+ * @param item `ExprCode` representing the new expression to evaluate for the result
+ */
+ def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
+ s"""
+ |if (!${item.isNull} && (${partialResult.isNull} ||
+ | ${genGreater(dataType, partialResult.value, item.value)})) {
+ | ${partialResult.isNull} = false;
+ | ${partialResult.value} = ${item.value};
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Generates code for updating `partialResult` if `item` is greater than it.
+ *
+ * @param dataType data type of the expressions
+ * @param partialResult `ExprCode` representing the partial result which has to be updated
+ * @param item `ExprCode` representing the new expression to evaluate for the result
+ */
+ def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
+ s"""
+ |if (!${item.isNull} && (${partialResult.isNull} ||
+ | ${genGreater(dataType, item.value, partialResult.value)})) {
+ | ${partialResult.isNull} = false;
+ | ${partialResult.value} = ${item.value};
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Generates code creating a [[UnsafeArrayData]].
+ *
+ * @param arrayName name of the array to create
+ * @param numElements code representing the number of elements the array should contain
+ * @param elementType data type of the elements in the array
+ * @param additionalErrorMessage string to include in the error message
+ */
+ def createUnsafeArray(
+ arrayName: String,
+ numElements: String,
+ elementType: DataType,
+ additionalErrorMessage: String): String = {
+ val arraySize = freshName("size")
+ val arrayBytes = freshName("arrayBytes")
+
+ s"""
+ |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+ | $numElements,
+ | ${elementType.defaultSize});
+ |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+ | throw new RuntimeException("Unsuccessful try create array with " + $arraySize +
+ | " bytes of data due to exceeding the limit " +
+ | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." +
+ | "$additionalErrorMessage");
+ |}
+ |byte[] $arrayBytes = new byte[(int)$arraySize];
+ |UnsafeArrayData $arrayName = new UnsafeArrayData();
+ |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
+ |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
+ """.stripMargin
+ }
+
+ /**
+ * Generates code creating a [[UnsafeArrayData]]. The generated code executes
+ * a provided fallback when the size of the backing array would exceed the array size limit.
+ * @param arrayName the name of the array to create
+ * @param numElements a piece of code representing the number of elements the array should contain
+ * @param elementSize the size of an element in bytes
+ * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]],
+ * taking the backing byte array as a parameter
+ * @param fallbackCode a piece of code executed when the array size limit is exceeded
+ */
+ def createUnsafeArrayWithFallback(
+ arrayName: String,
+ numElements: String,
+ elementSize: Int,
+ bodyCode: String => String,
+ fallbackCode: String): String = {
+ val arraySize = freshName("size")
+ val arrayBytes = freshName("arrayBytes")
+ s"""
+ |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+ | $numElements,
+ | $elementSize);
+ |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+ | $fallbackCode
+ |} else {
+ | final byte[] $arrayBytes = new byte[(int)$arraySize];
+ | UnsafeArrayData $arrayName = new UnsafeArrayData();
+ | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
+ | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
+ | ${bodyCode(arrayBytes)}
+ |}
+ """.stripMargin
+ }
+
/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
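
Note: both array helpers added above fail when the computed byte size exceeds ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH. A rough sketch of a call site inside some doGenCode (the surrounding names are illustrative, not from this patch):

    val numElems = ctx.freshName("numElems")
    val allocation = ctx.createUnsafeArray(
      arrayName = "resultArray",
      numElements = numElems,
      elementType = LongType,
      additionalErrorMessage = " Consider reducing the number of elements.")
    // `allocation` declares `resultArray` as an UnsafeArrayData backed by a fresh byte[],
    // throwing at runtime if the requested size exceeds the UnsafeArrayData limit
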
@@ -900,19 +819,6 @@ class CodegenContext {
}
}
- /**
- * List of java data types that have special accessors and setters in [[InternalRow]].
- */
- val primitiveTypes =
- Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE)
-
- /**
- * Returns true if the Java type has a special accessor and setter in [[InternalRow]].
- */
- def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
-
- def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
-
/**
* Splits the generated code of expressions into multiple functions, because function has
* 64kb code size limit in JVM. If the class to which the function would be inlined would grow
@@ -1083,7 +989,7 @@ class CodegenContext {
// for performance reasons, the functions are prepended, instead of appended,
// thus here they are in reversed order
val orderedFunctions = innerClassFunctions.reverse
- if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) {
+ if (orderedFunctions.size > MERGE_SPLIT_METHODS_THRESHOLD) {
// Adding a new function to each inner class which contains the invocation of all the
// ones which have been added to that inner class. For example,
// private class NestedClass {
@@ -1118,14 +1024,12 @@ class CodegenContext {
newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])(
f: => Seq[ExprCode]): Seq[ExprCode] = {
val oldsubExprEliminationExprs = subExprEliminationExprs
- subExprEliminationExprs.clear
- newSubExprEliminationExprs.foreach(subExprEliminationExprs += _)
+ subExprEliminationExprs = newSubExprEliminationExprs
val genCodes = f
// Restore previous subExprEliminationExprs
- subExprEliminationExprs.clear
- oldsubExprEliminationExprs.foreach(subExprEliminationExprs += _)
+ subExprEliminationExprs = oldsubExprEliminationExprs
genCodes
}
@@ -1139,7 +1043,7 @@ class CodegenContext {
def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
// Create a clear EquivalentExpressions and SubExprEliminationState mapping
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
- val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
+ val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree)
@@ -1152,10 +1056,10 @@ class CodegenContext {
// Generate the code for this expression tree.
val eval = expr.genCode(this)
val state = SubExprEliminationState(eval.isNull, eval.value)
- e.foreach(subExprEliminationExprs.put(_, state))
- eval.code.trim
+ e.foreach(localSubExprEliminationExprs.put(_, state))
+ eval.code.toString
}
- SubExprCodes(codes, subExprEliminationExprs.toMap)
+ SubExprCodes(codes, localSubExprEliminationExprs.toMap)
}
/**
@@ -1181,7 +1085,7 @@ class CodegenContext {
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
- | ${eval.code.trim}
+ | ${eval.code}
| $isNull = ${eval.isNull};
| $value = ${eval.value};
|}
@@ -1202,8 +1106,10 @@ class CodegenContext {
// at least two nodes) as the cost of doing it is expected to be low.
subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
- val state = SubExprEliminationState(isNull, value)
- e.foreach(subExprEliminationExprs.put(_, state))
+ val state = SubExprEliminationState(
+ JavaCode.isNullGlobal(isNull),
+ JavaCode.global(value, expr.dataType))
+ subExprEliminationExprs ++= e.map(_ -> state).toMap
}
}
@@ -1226,49 +1132,39 @@ class CodegenContext {
/**
* Register a comment and return the corresponding place holder
+ *
+ * @param placeholderId an optionally specified identifier for the comment's placeholder.
+ * The caller should make sure this identifier is unique within the
+ * compilation unit. If this argument is not specified, a fresh identifier
+ * will be automatically created and used as the placeholder.
+ * @param force whether to force registering the comments
*/
- def registerComment(text: => String): String = {
+ def registerComment(
+ text: => String,
+ placeholderId: String = "",
+ force: Boolean = false): Block = {
// By default, disable comments in generated code because computing the comments themselves can
// be extremely expensive in certain cases, such as deeply-nested expressions which operate over
// inputs with wide schemas. For more details on the performance issues that motivated this
// flag, see SPARK-15680.
- if (SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) {
- val name = freshName("c")
+ if (force ||
+ SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) {
+ val name = if (placeholderId != "") {
+ assert(!placeHolderToComments.contains(placeholderId))
+ placeholderId
+ } else {
+ freshName("c")
+ }
val comment = if (text.contains("\n") || text.contains("\r")) {
text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */")
} else {
s"// $text"
}
placeHolderToComments += (name -> comment)
- s"/*$name*/"
+ code"/*$name*/"
} else {
- ""
- }
- }
-
- /**
- * Returns the length of parameters for a Java method descriptor. `this` contributes one unit
- * and a parameter of type long or double contributes two units. Besides, for nullable parameter,
- * we also need to pass a boolean parameter for the null status.
- */
- def calculateParamLength(params: Seq[Expression]): Int = {
- def paramLengthForExpr(input: Expression): Int = {
- // For a nullable expression, we need to pass in an extra boolean parameter.
- (if (input.nullable) 1 else 0) + javaType(input.dataType) match {
- case JAVA_LONG | JAVA_DOUBLE => 2
- case _ => 1
- }
+ EmptyBlock
}
- // Initial value is 1 for `this`.
- 1 + params.map(paramLengthForExpr(_)).sum
- }
-
- /**
- * In Java, a method descriptor is valid only if it represents method parameters with a total
- * length less than a pre-defined constant.
- */
- def isValidParamLength(paramLength: Int): Boolean = {
- paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
}
}
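A brief, hedged sketch of the new `registerComment` signature in use inside a `doGenCode` implementation; the placeholder id, comment text, and surrounding `ev` are invented for illustration:

```scala
// Comments are off by default; `force = true` registers this one unconditionally.
val comment: Block = ctx.registerComment(
  "produced by MyExpression",          // hypothetical comment text
  placeholderId = "my_expr_comment",   // must be unique within the compilation unit
  force = true)
// The Block renders as the placeholder "/*my_expr_comment*/"; CodeFormatter later
// substitutes the registered text (or strips it when comments are disabled).
ev.copy(code = code"""
  |$comment
  |boolean ${ev.isNull} = false;
  """.stripMargin)
```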
@@ -1503,4 +1399,267 @@ object CodeGenerator extends Logging {
result
}
})
+
+ /**
+ * Names of Java primitive data types
+ */
+ final val JAVA_BOOLEAN = "boolean"
+ final val JAVA_BYTE = "byte"
+ final val JAVA_SHORT = "short"
+ final val JAVA_INT = "int"
+ final val JAVA_LONG = "long"
+ final val JAVA_FLOAT = "float"
+ final val JAVA_DOUBLE = "double"
+
+ /**
+ * List of Java primitive data types
+ */
+ val primitiveTypes =
+ Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE)
+
+ /**
+ * Returns true if a Java type is a Java primitive type
+ */
+ def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
+
+ def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
+
+ /**
+ * Returns the specialized code to access a value from `inputRow` at `ordinal`.
+ */
+ def getValue(input: String, dataType: DataType, ordinal: String): String = {
+ val jt = javaType(dataType)
+ dataType match {
+ case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)"
+ case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})"
+ case StringType => s"$input.getUTF8String($ordinal)"
+ case BinaryType => s"$input.getBinary($ordinal)"
+ case CalendarIntervalType => s"$input.getInterval($ordinal)"
+ case t: StructType => s"$input.getStruct($ordinal, ${t.size})"
+ case _: ArrayType => s"$input.getArray($ordinal)"
+ case _: MapType => s"$input.getMap($ordinal)"
+ case NullType => "null"
+ case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
+ case _ => s"($jt)$input.get($ordinal, null)"
+ }
+ }
+
+ /**
+ * Returns the code to update a column in Row for a given DataType.
+ */
+ def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
+ val jt = javaType(dataType)
+ dataType match {
+ case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
+ case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
+ case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
+ // The UTF8String, InternalRow, ArrayData and MapData may come from an UnsafeRow, so we should
+ // copy them to avoid keeping a "pointer" to a memory region which may get updated afterwards.
+ case StringType | _: StructType | _: ArrayType | _: MapType =>
+ s"$row.update($ordinal, $value.copy())"
+ case _ => s"$row.update($ordinal, $value)"
+ }
+ }
+
+ /**
+ * Update a column in MutableRow from ExprCode.
+ *
+ * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
+ */
+ def updateColumn(
+ row: String,
+ dataType: DataType,
+ ordinal: Int,
+ ev: ExprCode,
+ nullable: Boolean,
+ isVectorized: Boolean = false): String = {
+ if (nullable) {
+ // Can't call setNullAt on DecimalType, because we need to keep the offset
+ if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
+ s"""
+ |if (!${ev.isNull}) {
+ | ${setColumn(row, dataType, ordinal, ev.value)};
+ |} else {
+ | ${setColumn(row, dataType, ordinal, "null")};
+ |}
+ """.stripMargin
+ } else {
+ s"""
+ |if (!${ev.isNull}) {
+ | ${setColumn(row, dataType, ordinal, ev.value)};
+ |} else {
+ | $row.setNullAt($ordinal);
+ |}
+ """.stripMargin
+ }
+ } else {
+ s"""${setColumn(row, dataType, ordinal, ev.value)};"""
+ }
+ }
+
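To illustrate the kind of snippet this helper emits, a hedged example for a nullable int column; the names `isNull_0`, `value_0`, and `mutableRow` are invented:

```scala
val ev = ExprCode(
  JavaCode.isNullVariable("isNull_0"),
  JavaCode.variable("value_0", IntegerType))
val update = CodeGenerator.updateColumn("mutableRow", IntegerType, 0, ev, nullable = true)
// `update` is roughly:
//   if (!isNull_0) {
//     mutableRow.setInt(0, value_0);
//   } else {
//     mutableRow.setNullAt(0);
//   }
```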
+ /**
+ * Returns the specialized code to set a given value in a column vector for a given `DataType`.
+ */
+ def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
+ val jt = javaType(dataType)
+ dataType match {
+ case _ if isPrimitiveType(jt) =>
+ s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
+ case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
+ case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
+ case _ =>
+ throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
+ }
+ }
+
+ /**
+ * Returns the specialized code to set a given value in a column vector for a given `DataType`
+ * that could potentially be nullable.
+ */
+ def updateColumn(
+ vector: String,
+ rowId: String,
+ dataType: DataType,
+ ev: ExprCode,
+ nullable: Boolean): String = {
+ if (nullable) {
+ s"""
+ |if (!${ev.isNull}) {
+ | ${setValue(vector, rowId, dataType, ev.value)}
+ |} else {
+ | $vector.putNull($rowId);
+ |}
+ """.stripMargin
+ } else {
+ s"""${setValue(vector, rowId, dataType, ev.value)};"""
+ }
+ }
+
+ /**
+ * Returns the specialized code to access a value from a column vector for a given `DataType`.
+ */
+ def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = {
+ if (dataType.isInstanceOf[StructType]) {
+ // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an
+ // `ordinal` parameter.
+ s"$vector.getStruct($rowId)"
+ } else {
+ getValue(vector, dataType, rowId)
+ }
+ }
+
+ /**
+ * Returns the name used in accessor and setter for a Java primitive type.
+ */
+ def primitiveTypeName(jt: String): String = jt match {
+ case JAVA_INT => "Int"
+ case _ => boxedType(jt)
+ }
+
+ def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt))
+
+ /**
+ * Returns the Java type for a DataType.
+ */
+ def javaType(dt: DataType): String = dt match {
+ case BooleanType => JAVA_BOOLEAN
+ case ByteType => JAVA_BYTE
+ case ShortType => JAVA_SHORT
+ case IntegerType | DateType => JAVA_INT
+ case LongType | TimestampType => JAVA_LONG
+ case FloatType => JAVA_FLOAT
+ case DoubleType => JAVA_DOUBLE
+ case _: DecimalType => "Decimal"
+ case BinaryType => "byte[]"
+ case StringType => "UTF8String"
+ case CalendarIntervalType => "CalendarInterval"
+ case _: StructType => "InternalRow"
+ case _: ArrayType => "ArrayData"
+ case _: MapType => "MapData"
+ case udt: UserDefinedType[_] => javaType(udt.sqlType)
+ case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
+ case ObjectType(cls) => cls.getName
+ case _ => "Object"
+ }
+
+ def javaClass(dt: DataType): Class[_] = dt match {
+ case BooleanType => java.lang.Boolean.TYPE
+ case ByteType => java.lang.Byte.TYPE
+ case ShortType => java.lang.Short.TYPE
+ case IntegerType | DateType => java.lang.Integer.TYPE
+ case LongType | TimestampType => java.lang.Long.TYPE
+ case FloatType => java.lang.Float.TYPE
+ case DoubleType => java.lang.Double.TYPE
+ case _: DecimalType => classOf[Decimal]
+ case BinaryType => classOf[Array[Byte]]
+ case StringType => classOf[UTF8String]
+ case CalendarIntervalType => classOf[CalendarInterval]
+ case _: StructType => classOf[InternalRow]
+ case _: ArrayType => classOf[ArrayData]
+ case _: MapType => classOf[MapData]
+ case udt: UserDefinedType[_] => javaClass(udt.sqlType)
+ case ObjectType(cls) => cls
+ case _ => classOf[Object]
+ }
+
+ /**
+ * Returns the boxed type in Java.
+ */
+ def boxedType(jt: String): String = jt match {
+ case JAVA_BOOLEAN => "Boolean"
+ case JAVA_BYTE => "Byte"
+ case JAVA_SHORT => "Short"
+ case JAVA_INT => "Integer"
+ case JAVA_LONG => "Long"
+ case JAVA_FLOAT => "Float"
+ case JAVA_DOUBLE => "Double"
+ case other => other
+ }
+
+ def boxedType(dt: DataType): String = boxedType(javaType(dt))
+
+ /**
+ * Returns the representation of the default value for a given Java type.
+ * @param jt the string name of the Java type
+ * @param typedNull if true, for null literals, return a typed (with a cast) version
+ */
+ def defaultValue(jt: String, typedNull: Boolean): String = jt match {
+ case JAVA_BOOLEAN => "false"
+ case JAVA_BYTE => "(byte)-1"
+ case JAVA_SHORT => "(short)-1"
+ case JAVA_INT => "-1"
+ case JAVA_LONG => "-1L"
+ case JAVA_FLOAT => "-1.0f"
+ case JAVA_DOUBLE => "-1.0"
+ case _ => if (typedNull) s"(($jt)null)" else "null"
+ }
+
+ def defaultValue(dt: DataType, typedNull: Boolean = false): String =
+ defaultValue(javaType(dt), typedNull)
+
+ /**
+ * Returns the length of parameters for a Java method descriptor. `this` contributes one unit
+ * and a parameter of type long or double contributes two units. In addition, a nullable
+ * parameter requires an extra boolean parameter to carry its null status.
+ */
+ def calculateParamLength(params: Seq[Expression]): Int = {
+ def paramLengthForExpr(input: Expression): Int = {
+ val javaParamLength = javaType(input.dataType) match {
+ case JAVA_LONG | JAVA_DOUBLE => 2
+ case _ => 1
+ }
+ // For a nullable expression, we need to pass in an extra boolean parameter.
+ (if (input.nullable) 1 else 0) + javaParamLength
+ }
+ // Initial value is 1 for `this`.
+ 1 + params.map(paramLengthForExpr).sum
+ }
+
+ /**
+ * In Java, a method descriptor is valid only if it represents method parameters with a total
+ * length less than a pre-defined constant.
+ */
+ def isValidParamLength(paramLength: Int): Boolean = {
+ paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH
+ }
}
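For orientation, a hedged sample of what the relocated helpers return; the row and ordinal names are arbitrary and the outputs shown as comments are approximate:

```scala
CodeGenerator.javaType(TimestampType)                      // "long"
CodeGenerator.boxedType(IntegerType)                       // "Integer"
CodeGenerator.defaultValue(StringType, typedNull = true)   // "((UTF8String)null)"
CodeGenerator.getValue("row", DecimalType(10, 2), "0")     // "row.getDecimal(0, 10, 2)"

// `this` + null flag + two slots for the long value = 4 parameter units.
val len = CodeGenerator.calculateParamLength(Seq(AttributeReference("a", LongType)()))
CodeGenerator.isValidParamLength(len)                      // true, 4 <= MAX_JVM_METHOD_PARAMS_LENGTH
```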
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index 0322d1dd6a9ff..3f4704d287cbd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
/**
* A trait that can be used to provide a fallback mode for expression code generation.
@@ -44,21 +45,22 @@ trait CodegenFallback extends Expression {
}
val objectTerm = ctx.freshName("obj")
val placeHolder = ctx.registerComment(this.toString)
+ val javaType = CodeGenerator.javaType(this.dataType)
if (nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
- ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)};
if (!${ev.isNull}) {
- ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
+ ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
}""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
- ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
- """, isNull = "false")
+ $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
+ """, isNull = FalseLiteral)
}
}
}
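As a hedged illustration of when this fallback path is taken, a trivial expression that opts into `CodegenFallback`; the class is invented for this example, and the snippet above is roughly what gets generated for it:

```scala
case class AlwaysZero() extends LeafExpression with CodegenFallback {
  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType
  // Only eval() is implemented; doGenCode() comes from CodegenFallback and simply
  // calls back into eval() through the references array.
  override def eval(input: InternalRow): Any = 0
}
```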
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index b53c0087e7e2d..33d14329ec95c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -52,43 +52,45 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
expressions: Seq[Expression],
useSubexprElimination: Boolean): MutableProjection = {
val ctx = newCodeGenContext()
- val (validExpr, index) = expressions.zipWithIndex.filter {
+ val validExpr = expressions.zipWithIndex.filter {
case (NoOp, _) => false
case _ => true
- }.unzip
- val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)
+ }
+ val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination)
// 4-tuples: (code for projection, isNull variable name, value variable name, column index)
- val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map {
- case (ev, i) =>
- val e = expressions(i)
- val value = ctx.addMutableState(ctx.javaType(e.dataType), "value")
- if (e.nullable) {
- val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull")
+ val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map {
+ case ((e, i), ev) =>
+ val value = JavaCode.global(
+ ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value"),
+ e.dataType)
+ val (code, isNull) = if (e.nullable) {
+ val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull")
(s"""
|${ev.code}
|$isNull = ${ev.isNull};
|$value = ${ev.value};
- """.stripMargin, isNull, value, i)
+ """.stripMargin, JavaCode.isNullGlobal(isNull))
} else {
(s"""
|${ev.code}
|$value = ${ev.value};
- """.stripMargin, ev.isNull, value, i)
+ """.stripMargin, FalseLiteral)
}
+ val update = CodeGenerator.updateColumn(
+ "mutableRow",
+ e.dataType,
+ i,
+ ExprCode(isNull, value),
+ e.nullable)
+ (code, update)
}
// Evaluate all the subexpressions.
val evalSubexpr = ctx.subexprFunctions.mkString("\n")
- val updates = validExpr.zip(projectionCodes).map {
- case (e, (_, isNull, value, i)) =>
- val ev = ExprCode("", isNull, value)
- ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
- }
-
val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1))
- val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates)
+ val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2))
val codeBody = s"""
public java.lang.Object generate(Object[] references) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 4a459571ed634..9a51be6ed5aeb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -89,7 +89,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
s"""
${ctx.INPUT_ROW} = a;
boolean $isNullA;
- ${ctx.javaType(order.child.dataType)} $primitiveA;
+ ${CodeGenerator.javaType(order.child.dataType)} $primitiveA;
{
${eval.code}
$isNullA = ${eval.isNull};
@@ -97,7 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
}
${ctx.INPUT_ROW} = b;
boolean $isNullB;
- ${ctx.javaType(order.child.dataType)} $primitiveB;
+ ${CodeGenerator.javaType(order.child.dataType)} $primitiveB;
{
${eval.code}
$isNullB = ${eval.isNull};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 3dcbb518ba42a..39778661d1c48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import scala.annotation.tailrec
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
/**
@@ -53,7 +55,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val rowClass = classOf[GenericInternalRow].getName
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
- val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt)
+ val converter = convertToSafe(
+ ctx,
+ JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt),
+ dt)
s"""
if (!$tmpInput.isNullAt($i)) {
${converter.code}
@@ -67,14 +72,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values)
)
val code =
- s"""
+ code"""
|final InternalRow $tmpInput = $input;
|final Object[] $values = new Object[${schema.length}];
|$allFields
|final InternalRow $output = new $rowClass($values);
""".stripMargin
- ExprCode(code, "false", output)
+ ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[InternalRow]))
}
private def createCodeForArray(
@@ -90,8 +95,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val arrayClass = classOf[GenericArrayData].getName
val elementConverter = convertToSafe(
- ctx, ctx.getValue(tmpInput, elementType, index), elementType)
- val code = s"""
+ ctx,
+ JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType),
+ elementType)
+ val code = code"""
final ArrayData $tmpInput = $input;
final int $numElements = $tmpInput.numElements();
final Object[] $values = new Object[$numElements];
@@ -104,7 +111,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
final ArrayData $output = new $arrayClass($values);
"""
- ExprCode(code, "false", output)
+ ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[ArrayData]))
}
private def createCodeForMap(
@@ -118,26 +125,26 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType)
val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType)
- val code = s"""
+ val code = code"""
final MapData $tmpInput = $input;
${keyConverter.code}
${valueConverter.code}
final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
"""
- ExprCode(code, "false", output)
+ ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[MapData]))
}
@tailrec
private def convertToSafe(
ctx: CodegenContext,
- input: String,
+ input: ExprValue,
dataType: DataType): ExprCode = dataType match {
case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
- case _ => ExprCode("", "false", input)
+ case _ => ExprCode(FalseLiteral, input)
}
protected def create(expressions: Seq[Expression]): Projection = {
@@ -153,7 +160,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
mutableRow.setNullAt($i);
} else {
${converter.code}
- ${ctx.setColumn("mutableRow", e.dataType, i, converter.value)};
+ ${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)};
}
"""
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 36ffa8dcdd2b6..8f2a5a0dce943 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@@ -32,14 +33,13 @@ import org.apache.spark.sql.types._
object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
/** Returns true iff we support this data type. */
- def canSupport(dataType: DataType): Boolean = dataType match {
+ def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match {
case NullType => true
- case t: AtomicType => true
+ case _: AtomicType => true
case _: CalendarIntervalType => true
case t: StructType => t.forall(field => canSupport(field.dataType))
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
- case udt: UserDefinedType[_] => canSupport(udt.sqlType)
case _ => false
}
@@ -47,22 +47,33 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private def writeStructToBuffer(
ctx: CodegenContext,
input: String,
+ index: String,
fieldTypes: Seq[DataType],
- bufferHolder: String): String = {
+ rowWriter: String): String = {
// Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
- ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString))
+ ExprCode(
+ JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"),
+ JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
}
+ val rowWriterClass = classOf[UnsafeRowWriter].getName
+ val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
+ v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});")
+ val previousCursor = ctx.freshName("previousCursor")
s"""
- final InternalRow $tmpInput = $input;
- if ($tmpInput instanceof UnsafeRow) {
- ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)}
- } else {
- ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)}
- }
- """
+ |final InternalRow $tmpInput = $input;
+ |if ($tmpInput instanceof UnsafeRow) {
+ | $rowWriter.write($index, (UnsafeRow) $tmpInput);
+ |} else {
+ | // Remember the current cursor so that we can calculate how many bytes are
+ | // written later.
+ | final int $previousCursor = $rowWriter.cursor();
+ | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)}
+ | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
+ |}
+ """.stripMargin
}
private def writeExpressionsToBuffer(
@@ -70,12 +81,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
row: String,
inputs: Seq[ExprCode],
inputTypes: Seq[DataType],
- bufferHolder: String,
+ rowWriter: String,
isTopLevel: Boolean = false): String = {
- val rowWriterClass = classOf[UnsafeRowWriter].getName
- val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
- v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});")
-
val resetWriter = if (isTopLevel) {
// For top level row writer, it always writes to the beginning of the global buffer holder,
// which means its fixed-size region always in the same position, so we don't need to call
@@ -88,16 +95,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"$rowWriter.zeroOutNullBytes();"
}
} else {
- s"$rowWriter.reset();"
+ s"$rowWriter.resetRowWriter();"
}
val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
case ((input, dataType), index) =>
- val dt = dataType match {
- case udt: UserDefinedType[_] => udt.sqlType
- case other => other
- }
- val tmpCursor = ctx.freshName("tmpCursor")
+ val dt = UserDefinedType.sqlType(dataType)
val setNull = dt match {
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
@@ -106,56 +109,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => s"$rowWriter.setNullAt($index);"
}
- val writeField = dt match {
- case t: StructType =>
- s"""
- // Remember the current cursor so that we can calculate how many bytes are
- // written later.
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)}
- $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
- """
-
- case a @ ArrayType(et, _) =>
- s"""
- // Remember the current cursor so that we can calculate how many bytes are
- // written later.
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)}
- $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
- """
-
- case m @ MapType(kt, vt, _) =>
- s"""
- // Remember the current cursor so that we can calculate how many bytes are
- // written later.
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)}
- $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
- """
-
- case t: DecimalType =>
- s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});"
-
- case NullType => ""
-
- case _ => s"$rowWriter.write($index, ${input.value});"
- }
-
- if (input.isNull == "false") {
+ val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
+ if (input.isNull == FalseLiteral) {
s"""
- ${input.code}
- ${writeField.trim}
- """
+ |${input.code}
+ |${writeField.trim}
+ """.stripMargin
} else {
s"""
- ${input.code}
- if (${input.isNull}) {
- ${setNull.trim}
- } else {
- ${writeField.trim}
- }
- """
+ |${input.code}
+ |if (${input.isNull}) {
+ | ${setNull.trim}
+ |} else {
+ | ${writeField.trim}
+ |}
+ """.stripMargin
}
}
@@ -169,11 +137,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
funcName = "writeFields",
arguments = Seq("InternalRow" -> row))
}
-
s"""
- $resetWriter
- $writeFieldsCode
- """.trim
+ |$resetWriter
+ |$writeFieldsCode
+ """.stripMargin
}
// TODO: if the nullability of array element is correct, we can use it to save null check.
@@ -181,126 +148,119 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodegenContext,
input: String,
elementType: DataType,
- bufferHolder: String): String = {
+ rowWriter: String): String = {
// Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
- val arrayWriterClass = classOf[UnsafeArrayWriter].getName
- val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter",
- v => s"$v = new $arrayWriterClass();")
val numElements = ctx.freshName("numElements")
val index = ctx.freshName("index")
- val et = elementType match {
- case udt: UserDefinedType[_] => udt.sqlType
- case other => other
- }
+ val et = UserDefinedType.sqlType(elementType)
- val jt = ctx.javaType(et)
+ val jt = CodeGenerator.javaType(et)
val elementOrOffsetSize = et match {
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
- case _ if ctx.isPrimitiveType(jt) => et.defaultSize
+ case _ if CodeGenerator.isPrimitiveType(jt) => et.defaultSize
case _ => 8 // we need 8 bytes to store offset and length
}
- val tmpCursor = ctx.freshName("tmpCursor")
- val element = ctx.getValue(tmpInput, et, index)
- val writeElement = et match {
- case t: StructType =>
- s"""
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
- $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
- """
-
- case a @ ArrayType(et, _) =>
- s"""
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeArrayToBuffer(ctx, element, et, bufferHolder)}
- $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
- """
-
- case m @ MapType(kt, vt, _) =>
- s"""
- final int $tmpCursor = $bufferHolder.cursor;
- ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
- $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
- """
-
- case t: DecimalType =>
- s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});"
-
- case NullType => ""
-
- case _ => s"$arrayWriter.write($index, $element);"
- }
+ val arrayWriterClass = classOf[UnsafeArrayWriter].getName
+ val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter",
+ v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);")
+
+ val element = CodeGenerator.getValue(tmpInput, et, index)
- val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
s"""
- final ArrayData $tmpInput = $input;
- if ($tmpInput instanceof UnsafeArrayData) {
- ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)}
- } else {
- final int $numElements = $tmpInput.numElements();
- $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize);
-
- for (int $index = 0; $index < $numElements; $index++) {
- if ($tmpInput.isNullAt($index)) {
- $arrayWriter.setNull$primitiveTypeName($index);
- } else {
- $writeElement
- }
- }
- }
- """
+ |final ArrayData $tmpInput = $input;
+ |if ($tmpInput instanceof UnsafeArrayData) {
+ | $rowWriter.write((UnsafeArrayData) $tmpInput);
+ |} else {
+ | final int $numElements = $tmpInput.numElements();
+ | $arrayWriter.initialize($numElements);
+ |
+ | for (int $index = 0; $index < $numElements; $index++) {
+ | if ($tmpInput.isNullAt($index)) {
+ | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index);
+ | } else {
+ | ${writeElement(ctx, element, index, et, arrayWriter)}
+ | }
+ | }
+ |}
+ """.stripMargin
}
// TODO: if the nullability of value element is correct, we can use it to save null check.
private def writeMapToBuffer(
ctx: CodegenContext,
input: String,
+ index: String,
keyType: DataType,
valueType: DataType,
- bufferHolder: String): String = {
+ rowWriter: String): String = {
// Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val tmpCursor = ctx.freshName("tmpCursor")
+ val previousCursor = ctx.freshName("previousCursor")
// Writes out unsafe map according to the format described in `UnsafeMapData`.
s"""
- final MapData $tmpInput = $input;
- if ($tmpInput instanceof UnsafeMapData) {
- ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)}
- } else {
- // preserve 8 bytes to write the key array numBytes later.
- $bufferHolder.grow(8);
- $bufferHolder.cursor += 8;
+ |final MapData $tmpInput = $input;
+ |if ($tmpInput instanceof UnsafeMapData) {
+ | $rowWriter.write($index, (UnsafeMapData) $tmpInput);
+ |} else {
+ | // Remember the current cursor so that we can calculate how many bytes are
+ | // written later.
+ | final int $previousCursor = $rowWriter.cursor();
+ |
+ | // preserve 8 bytes to write the key array numBytes later.
+ | $rowWriter.grow(8);
+ | $rowWriter.increaseCursor(8);
+ |
+ | // Remember the current cursor so that we can write numBytes of key array later.
+ | final int $tmpCursor = $rowWriter.cursor();
+ |
+ | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)}
+ |
+ | // Write the numBytes of key array into the first 8 bytes.
+ | Platform.putLong(
+ | $rowWriter.getBuffer(),
+ | $tmpCursor - 8,
+ | $rowWriter.cursor() - $tmpCursor);
+ |
+ | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)}
+ | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
+ |}
+ """.stripMargin
+ }
- // Remember the current cursor so that we can write numBytes of key array later.
- final int $tmpCursor = $bufferHolder.cursor;
+ private def writeElement(
+ ctx: CodegenContext,
+ input: String,
+ index: String,
+ dt: DataType,
+ writer: String): String = dt match {
+ case t: StructType =>
+ writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer)
+
+ case ArrayType(et, _) =>
+ val previousCursor = ctx.freshName("previousCursor")
+ s"""
+ |// Remember the current cursor so that we can calculate how many bytes are
+ |// written later.
+ |final int $previousCursor = $writer.cursor();
+ |${writeArrayToBuffer(ctx, input, et, writer)}
+ |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor);
+ """.stripMargin
- ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)}
- // Write the numBytes of key array into the first 8 bytes.
- Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor);
+ case MapType(kt, vt, _) =>
+ writeMapToBuffer(ctx, input, index, kt, vt, writer)
- ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)}
- }
- """
- }
+ case DecimalType.Fixed(precision, scale) =>
+ s"$writer.write($index, $input, $precision, $scale);"
- /**
- * If the input is already in unsafe format, we don't need to go through all elements/fields,
- * we can directly write it.
- */
- private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = {
- val sizeInBytes = ctx.freshName("sizeInBytes")
- s"""
- final int $sizeInBytes = $input.getSizeInBytes();
- // grow the global buffer before writing data.
- $bufferHolder.grow($sizeInBytes);
- $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor);
- $bufferHolder.cursor += $sizeInBytes;
- """
+ case NullType => ""
+
+ case _ => s"$writer.write($index, $input);"
}
def createCode(
@@ -316,38 +276,24 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => true
}
- val result = ctx.addMutableState("UnsafeRow", "result",
- v => s"$v = new UnsafeRow(${expressions.length});")
-
- val holderClass = classOf[BufferHolder].getName
- val holder = ctx.addMutableState(holderClass, "holder",
- v => s"$v = new $holderClass($result, ${numVarLenFields * 32});")
-
- val resetBufferHolder = if (numVarLenFields == 0) {
- ""
- } else {
- s"$holder.reset();"
- }
- val updateRowSize = if (numVarLenFields == 0) {
- ""
- } else {
- s"$result.setTotalSize($holder.totalSize());"
- }
+ val rowWriterClass = classOf[UnsafeRowWriter].getName
+ val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
+ v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});")
// Evaluate all the subexpression.
val evalSubexpr = ctx.subexprFunctions.mkString("\n")
- val writeExpressions =
- writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true)
+ val writeExpressions = writeExpressionsToBuffer(
+ ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
val code =
- s"""
- $resetBufferHolder
- $evalSubexpr
- $writeExpressions
- $updateRowSize
- """
- ExprCode(code, "false", result)
+ code"""
+ |$rowWriter.reset();
+ |$evalSubexpr
+ |$writeExpressions
+ """.stripMargin
+ // `rowWriter` is declared as a class field, so we can access it directly in methods.
+ ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow]))
}
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
@@ -372,38 +318,39 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val ctx = newCodeGenContext()
val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)
- val codeBody = s"""
- public java.lang.Object generate(Object[] references) {
- return new SpecificUnsafeProjection(references);
- }
-
- class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} {
-
- private Object[] references;
- ${ctx.declareMutableStates()}
-
- public SpecificUnsafeProjection(Object[] references) {
- this.references = references;
- ${ctx.initMutableStates()}
- }
-
- public void initialize(int partitionIndex) {
- ${ctx.initPartition()}
- }
-
- // Scala.Function1 need this
- public java.lang.Object apply(java.lang.Object row) {
- return apply((InternalRow) row);
- }
-
- public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
- ${eval.code.trim}
- return ${eval.value};
- }
-
- ${ctx.declareAddedFunctions()}
- }
- """
+ val codeBody =
+ s"""
+ |public java.lang.Object generate(Object[] references) {
+ | return new SpecificUnsafeProjection(references);
+ |}
+ |
+ |class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} {
+ |
+ | private Object[] references;
+ | ${ctx.declareMutableStates()}
+ |
+ | public SpecificUnsafeProjection(Object[] references) {
+ | this.references = references;
+ | ${ctx.initMutableStates()}
+ | }
+ |
+ | public void initialize(int partitionIndex) {
+ | ${ctx.initPartition()}
+ | }
+ |
+ | // Scala.Function1 need this
+ | public java.lang.Object apply(java.lang.Object row) {
+ | return apply((InternalRow) row);
+ | }
+ |
+ | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
+ | ${eval.code}
+ | return ${eval.value};
+ | }
+ |
+ | ${ctx.declareAddedFunctions()}
+ |}
+ """.stripMargin
val code = CodeFormatter.stripOverlappingComments(
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
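A hedged end-to-end sketch of the rewritten projection path; the bound references and the input row are made up for illustration:

```scala
// Project a nullable int column and a nullable string column into an UnsafeRow.
val proj = GenerateUnsafeProjection.generate(Seq(
  BoundReference(0, IntegerType, nullable = true),
  BoundReference(1, StringType, nullable = true)))
val row = proj(InternalRow(1, UTF8String.fromString("a")))
// The single top-level UnsafeRowWriter owns the buffer; the variable-length string
// is written after the fixed-length region and the row size is tracked by the writer.
assert(row.getInt(0) == 1)
assert(row.getUTF8String(1).toString == "a")
```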
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
new file mode 100644
index 0000000000000..250ce48d059e0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import java.lang.{Boolean => JBool}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.language.{existentials, implicitConversions}
+
+import org.apache.spark.sql.types.{BooleanType, DataType}
+
+/**
+ * Trait representing an opaque fragment of Java code.
+ */
+trait JavaCode {
+ def code: String
+ override def toString: String = code
+}
+
+/**
+ * Utility functions for creating [[JavaCode]] fragments.
+ */
+object JavaCode {
+ /**
+ * Create a java literal.
+ */
+ def literal(v: String, dataType: DataType): LiteralValue = dataType match {
+ case BooleanType if v == "true" => TrueLiteral
+ case BooleanType if v == "false" => FalseLiteral
+ case _ => new LiteralValue(v, CodeGenerator.javaClass(dataType))
+ }
+
+ /**
+ * Create a default literal. This is null for reference types, false for boolean types and
+ * -1 for other primitive types.
+ */
+ def defaultLiteral(dataType: DataType): LiteralValue = {
+ new LiteralValue(
+ CodeGenerator.defaultValue(dataType, typedNull = true),
+ CodeGenerator.javaClass(dataType))
+ }
+
+ /**
+ * Create a local java variable.
+ */
+ def variable(name: String, dataType: DataType): VariableValue = {
+ variable(name, CodeGenerator.javaClass(dataType))
+ }
+
+ /**
+ * Create a local java variable.
+ */
+ def variable(name: String, javaClass: Class[_]): VariableValue = {
+ VariableValue(name, javaClass)
+ }
+
+ /**
+ * Create a local isNull variable.
+ */
+ def isNullVariable(name: String): VariableValue = variable(name, BooleanType)
+
+ /**
+ * Create a global java variable.
+ */
+ def global(name: String, dataType: DataType): GlobalValue = {
+ global(name, CodeGenerator.javaClass(dataType))
+ }
+
+ /**
+ * Create a global java variable.
+ */
+ def global(name: String, javaClass: Class[_]): GlobalValue = {
+ GlobalValue(name, javaClass)
+ }
+
+ /**
+ * Create a global isNull variable.
+ */
+ def isNullGlobal(name: String): GlobalValue = global(name, BooleanType)
+
+ /**
+ * Create an expression fragment.
+ */
+ def expression(code: String, dataType: DataType): SimpleExprValue = {
+ expression(code, CodeGenerator.javaClass(dataType))
+ }
+
+ /**
+ * Create an expression fragment.
+ */
+ def expression(code: String, javaClass: Class[_]): SimpleExprValue = {
+ SimpleExprValue(code, javaClass)
+ }
+
+ /**
+ * Create an isNull expression fragment.
+ */
+ def isNullExpression(code: String): SimpleExprValue = {
+ expression(code, BooleanType)
+ }
+}
+
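A short, hedged tour of these factory methods; all names are arbitrary and the trailing comments describe the expected results:

```scala
val flag  = JavaCode.isNullVariable("isNull_1")               // VariableValue of Java type boolean
val state = JavaCode.global("globalValue_0", StringType)      // GlobalValue typed as UTF8String
val zero  = JavaCode.defaultLiteral(IntegerType)              // LiteralValue rendering as "-1"
val get   = JavaCode.expression("row.getInt(0)", IntegerType) // SimpleExprValue, renders "(row.getInt(0))"
```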
+/**
+ * A trait representing a block of Java code.
+ */
+trait Block extends JavaCode {
+
+ // The expressions to be evaluated inside this block.
+ def exprValues: Set[ExprValue]
+
+ // Returns the Java code string for this code block.
+ override def toString: String = _marginChar match {
+ case Some(c) => code.stripMargin(c).trim
+ case _ => code.trim
+ }
+
+ def length: Int = toString.length
+
+ def nonEmpty: Boolean = toString.nonEmpty
+
+ // The leading prefix that should be stripped from each line.
+ // By default we strip blanks or control characters followed by '|' from the line.
+ var _marginChar: Option[Char] = Some('|')
+
+ def stripMargin(c: Char): this.type = {
+ _marginChar = Some(c)
+ this
+ }
+
+ def stripMargin: this.type = {
+ _marginChar = Some('|')
+ this
+ }
+
+ // Concatenates this block with another block.
+ def + (other: Block): Block
+}
+
+object Block {
+
+ val CODE_BLOCK_BUFFER_LENGTH: Int = 512
+
+ implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks)
+
+ implicit class BlockHelper(val sc: StringContext) extends AnyVal {
+ def code(args: Any*): Block = {
+ sc.checkLengths(args)
+ if (sc.parts.length == 0) {
+ EmptyBlock
+ } else {
+ args.foreach {
+ case _: ExprValue =>
+ case _: Int | _: Long | _: Float | _: Double | _: String =>
+ case _: Block =>
+ case other => throw new IllegalArgumentException(
+ s"Can not interpolate ${other.getClass.getName} into code block.")
+ }
+
+ val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args)
+ CodeBlock(codeParts, blockInputs)
+ }
+ }
+ }
+
+ // Eagerly folds the literal args into the code parts.
+ private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = {
+ val codeParts = ArrayBuffer.empty[String]
+ val blockInputs = ArrayBuffer.empty[JavaCode]
+
+ val strings = parts.iterator
+ val inputs = args.iterator
+ val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
+
+ buf.append(strings.next)
+ while (strings.hasNext) {
+ val input = inputs.next
+ input match {
+ case _: ExprValue | _: Block =>
+ codeParts += buf.toString
+ buf.clear
+ blockInputs += input.asInstanceOf[JavaCode]
+ case _ =>
+ buf.append(input)
+ }
+ buf.append(strings.next)
+ }
+ if (buf.nonEmpty) {
+ codeParts += buf.toString
+ }
+
+ (codeParts.toSeq, blockInputs.toSeq)
+ }
+}
+
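A hedged example of the `code` interpolator defined above (it requires `import Block._`; the variable names are invented):

```scala
val isNull = JavaCode.isNullVariable("isNull_0")
val value  = JavaCode.variable("value_0", IntegerType)
val block: Block = code"""
  |boolean $isNull = false;
  |int $value = ${42};
  """.stripMargin
// The Int literal 42 is folded eagerly into the code parts, while the two
// ExprValue inputs are retained, so block.exprValues == Set(isNull, value).
```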
+/**
+ * A block of Java code, consisting of a sequence of code parts and some inputs to this block.
+ * The actual Java code is generated by embedding the inputs into the code parts.
+ */
+case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block {
+ override lazy val exprValues: Set[ExprValue] = {
+ blockInputs.flatMap {
+ case b: Block => b.exprValues
+ case e: ExprValue => Set(e)
+ }.toSet
+ }
+
+ override lazy val code: String = {
+ val strings = codeParts.iterator
+ val inputs = blockInputs.iterator
+ val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
+ buf.append(StringContext.treatEscapes(strings.next))
+ while (strings.hasNext) {
+ buf.append(inputs.next)
+ buf.append(StringContext.treatEscapes(strings.next))
+ }
+ buf.toString
+ }
+
+ override def + (other: Block): Block = other match {
+ case c: CodeBlock => Blocks(Seq(this, c))
+ case b: Blocks => Blocks(Seq(this) ++ b.blocks)
+ case EmptyBlock => this
+ }
+}
+
+case class Blocks(blocks: Seq[Block]) extends Block {
+ override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet
+ override lazy val code: String = blocks.map(_.toString).mkString("\n")
+
+ override def + (other: Block): Block = other match {
+ case c: CodeBlock => Blocks(blocks :+ c)
+ case b: Blocks => Blocks(blocks ++ b.blocks)
+ case EmptyBlock => this
+ }
+}
+
+object EmptyBlock extends Block with Serializable {
+ override val code: String = ""
+ override val exprValues: Set[ExprValue] = Set.empty
+
+ override def + (other: Block): Block = other
+}
+
+/**
+ * A typed Java fragment that must be a valid Java expression.
+ */
+trait ExprValue extends JavaCode {
+ def javaType: Class[_]
+ def isPrimitive: Boolean = javaType.isPrimitive
+}
+
+object ExprValue {
+ implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code
+}
+
+/**
+ * A Java expression fragment.
+ */
+case class SimpleExprValue(expr: String, javaType: Class[_]) extends ExprValue {
+ override def code: String = s"($expr)"
+}
+
+/**
+ * A local variable Java expression.
+ */
+case class VariableValue(variableName: String, javaType: Class[_]) extends ExprValue {
+ override def code: String = variableName
+}
+
+/**
+ * A global variable Java expression.
+ */
+case class GlobalValue(value: String, javaType: Class[_]) extends ExprValue {
+ override def code: String = value
+}
+
+/**
+ * A literal Java expression.
+ */
+class LiteralValue(val value: String, val javaType: Class[_]) extends ExprValue with Serializable {
+ override def code: String = value
+
+ override def equals(arg: Any): Boolean = arg match {
+ case l: LiteralValue => l.javaType == javaType && l.value == value
+ case _ => false
+ }
+
+ override def hashCode(): Int = value.hashCode() * 31 + javaType.hashCode()
+}
+
+case object TrueLiteral extends LiteralValue("true", JBool.TYPE)
+case object FalseLiteral extends LiteralValue("false", JBool.TYPE)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 4270b987d6de0..d76f3013f0c41 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -18,11 +18,52 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Comparator
+import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
+
+/**
+ * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
+ * casting.
+ */
+trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
+ with ImplicitCastInputTypes {
+
+ @transient protected lazy val elementType: DataType =
+ inputTypes.head.asInstanceOf[ArrayType].elementType
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ (left.dataType, right.dataType) match {
+ case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) =>
+ TypeCoercion.findTightestCommonType(e1, e2) match {
+ case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2))
+ case _ => Seq.empty
+ }
+ case _ => Seq.empty
+ }
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " +
+ s"been two ${ArrayType.simpleString}s with same element type, but it's " +
+ s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]")
+ }
+ }
+}
+
/**
* Given an array or map, returns its size. Returns -1 if null.
@@ -51,11 +92,11 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = false;
${childGen.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
- (${childGen.value}).numElements();""", isNull = "false")
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
+ (${childGen.value}).numElements();""", isNull = FalseLiteral)
}
}
@@ -87,6 +128,172 @@ case class MapKeys(child: Expression)
override def prettyName: String = "map_keys"
}
+@ExpressionDescription(
+ usage = """
+ _FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all
+ N-th values of input arrays.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4));
+ [[1, 2], [2, 3], [3, 4]]
+ > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4));
+ [[1, 2, 3], [2, 3, 4]]
+ """,
+ since = "2.4.0")
+case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)
+
+ override def dataType: DataType = ArrayType(mountSchema)
+
+ override def nullable: Boolean = children.exists(_.nullable)
+
+ private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])
+
+ private lazy val arrayElementTypes = arrayTypes.map(_.elementType)
+
+ @transient private lazy val mountSchema: StructType = {
+ val fields = children.zip(arrayElementTypes).zipWithIndex.map {
+ case ((expr: NamedExpression, elementType), _) =>
+ StructField(expr.name, elementType, nullable = true)
+ case ((_, elementType), idx) =>
+ StructField(idx.toString, elementType, nullable = true)
+ }
+ StructType(fields)
+ }
+
+ @transient lazy val numberOfArrays: Int = children.length
+
+ @transient lazy val genericArrayData = classOf[GenericArrayData].getName
+
+ def emptyInputGenCode(ev: ExprCode): ExprCode = {
+ ev.copy(code"""
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]);
+ |boolean ${ev.isNull} = false;
+ """.stripMargin)
+ }
+
+ def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val genericInternalRow = classOf[GenericInternalRow].getName
+ val arrVals = ctx.freshName("arrVals")
+ val biggestCardinality = ctx.freshName("biggestCardinality")
+
+ val currentRow = ctx.freshName("currentRow")
+ val j = ctx.freshName("j")
+ val i = ctx.freshName("i")
+ val args = ctx.freshName("args")
+
+ val evals = children.map(_.genCode(ctx))
+ val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) =>
+ s"""
+ |if ($biggestCardinality != -1) {
+ | ${eval.code}
+ | if (!${eval.isNull}) {
+ | $arrVals[$index] = ${eval.value};
+ | $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements());
+ | } else {
+ | $biggestCardinality = -1;
+ | }
+ |}
+ """.stripMargin
+ }
+
+ val splittedGetValuesAndCardinalities = ctx.splitExpressions(
+ expressions = getValuesAndCardinalities,
+ funcName = "getValuesAndCardinalities",
+ returnType = "int",
+ makeSplitFunction = body =>
+ s"""
+ |$body
+ |return $biggestCardinality;
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"),
+ arguments =
+ ("ArrayData[]", arrVals) ::
+ ("int", biggestCardinality) :: Nil)
+
+ val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) =>
+ val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i)
+ s"""
+ |if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) {
+ | $currentRow[$idx] = $g;
+ |} else {
+ | $currentRow[$idx] = null;
+ |}
+ """.stripMargin
+ }
+
+ val getValueForTypeSplitted = ctx.splitExpressions(
+ expressions = getValueForType,
+ funcName = "extractValue",
+ arguments =
+ ("int", i) ::
+ ("Object[]", currentRow) ::
+ ("ArrayData[]", arrVals) :: Nil)
+
+ val initVariables = s"""
+ |ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
+ |int $biggestCardinality = 0;
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = null;
+ """.stripMargin
+
+ ev.copy(code"""
+ |$initVariables
+ |$splittedGetValuesAndCardinalities
+ |boolean ${ev.isNull} = $biggestCardinality == -1;
+ |if (!${ev.isNull}) {
+ | Object[] $args = new Object[$biggestCardinality];
+ | for (int $i = 0; $i < $biggestCardinality; $i ++) {
+ | Object[] $currentRow = new Object[$numberOfArrays];
+ | $getValueForTypeSplitted
+ | $args[$i] = new $genericInternalRow($currentRow);
+ | }
+ | ${ev.value} = new $genericArrayData($args);
+ |}
+ """.stripMargin)
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ if (numberOfArrays == 0) {
+ emptyInputGenCode(ev)
+ } else {
+ nonEmptyInputGenCode(ctx, ev)
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData])
+ if (inputArrays.contains(null)) {
+ null
+ } else {
+ val biggestCardinality = if (inputArrays.isEmpty) {
+ 0
+ } else {
+ inputArrays.map(_.numElements()).max
+ }
+
+ val result = new Array[InternalRow](biggestCardinality)
+ val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex
+
+ for (i <- 0 until biggestCardinality) {
+ val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) =>
+ if (i < arr.numElements() && !arr.isNullAt(i)) {
+ arr.get(i, arrayElementTypes(index))
+ } else {
+ null
+ }
+ }
+
+ result(i) = InternalRow.apply(currentLayer: _*)
+ }
+ new GenericArrayData(result)
+ }
+ }
+
+ override def prettyName: String = "arrays_zip"
+}
+
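
Illustrative aside, not part of the patch: a minimal plain-Scala sketch of the padding behaviour that ArraysZip.eval implements above, where every output row has one slot per input array and slots past an array's end are filled with null.

object ArraysZipSketch {
  // Mirrors ArraysZip.eval: zip to the longest input, padding missing slots with null.
  def zipWithPadding(arrays: Seq[Array[Any]]): Array[Array[Any]] = {
    val biggestCardinality = if (arrays.isEmpty) 0 else arrays.map(_.length).max
    Array.tabulate(biggestCardinality) { i =>
      arrays.map(arr => if (i < arr.length) arr(i) else null).toArray
    }
  }

  def main(args: Array[String]): Unit = {
    val zipped = zipWithPadding(Seq(Array[Any](1, 2, 3), Array[Any]("a", "b")))
    // Prints: List(1, a), List(2, b), List(3, null)
    zipped.foreach(row => println(row.toList))
  }
}
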
/**
* Returns an unordered array containing the values of the map.
*/
@@ -116,47 +323,168 @@ case class MapValues(child: Expression)
}
/**
- * Sorts the input array in ascending / descending order according to the natural ordering of
- * the array elements and returns it.
+ * Returns an unordered array of all entries in the given map.
*/
-// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.",
+ usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
examples = """
Examples:
- > SELECT _FUNC_(array('b', 'd', 'c', 'a'), true);
- ["a","b","c","d"]
- """)
-// scalastyle:on line.size.limit
-case class SortArray(base: Expression, ascendingOrder: Expression)
- extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+ > SELECT _FUNC_(map(1, 'a', 2, 'b'));
+ [(1,"a"),(2,"b")]
+ """,
+ since = "2.4.0")
+case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
- def this(e: Expression) = this(e, Literal(true))
+ override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
- override def left: Expression = base
- override def right: Expression = ascendingOrder
- override def dataType: DataType = base.dataType
- override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
+ lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
- override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
- case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
- ascendingOrder match {
- case Literal(_: Boolean, BooleanType) =>
- TypeCheckResult.TypeCheckSuccess
- case _ =>
- TypeCheckResult.TypeCheckFailure(
- "Sort order in second argument requires a boolean literal.")
+ override def dataType: DataType = {
+ ArrayType(
+ StructType(
+ StructField("key", childDataType.keyType, false) ::
+ StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
+ Nil),
+ false)
+ }
+
+ override protected def nullSafeEval(input: Any): Any = {
+ val childMap = input.asInstanceOf[MapData]
+ val keys = childMap.keyArray()
+ val values = childMap.valueArray()
+ val length = childMap.numElements()
+ val resultData = new Array[AnyRef](length)
+ var i = 0;
+ while (i < length) {
+ val key = keys.get(i, childDataType.keyType)
+ val value = values.get(i, childDataType.valueType)
+ val row = new GenericInternalRow(Array[Any](key, value))
+ resultData.update(i, row)
+ i += 1
+ }
+ new GenericArrayData(resultData)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => {
+ val numElements = ctx.freshName("numElements")
+ val keys = ctx.freshName("keys")
+ val values = ctx.freshName("values")
+ val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
+ val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
+ val code = if (isKeyPrimitive && isValuePrimitive) {
+ genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
+ } else {
+ genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
}
- case ArrayType(dt, _) =>
- TypeCheckResult.TypeCheckFailure(
- s"$prettyName does not support sorting array of type ${dt.simpleString}")
- case _ =>
- TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
+ s"""
+ |final int $numElements = $c.numElements();
+ |final ArrayData $keys = $c.keyArray();
+ |final ArrayData $values = $c.valueArray();
+ |$code
+ """.stripMargin
+ })
}
+ private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z")
+
+ private def getValue(varName: String) = {
+ CodeGenerator.getValue(varName, childDataType.valueType, "z")
+ }
+
+ private def genCodeForPrimitiveElements(
+ ctx: CodegenContext,
+ keys: String,
+ values: String,
+ arrayData: String,
+ numElements: String): String = {
+ val unsafeRow = ctx.freshName("unsafeRow")
+ val unsafeArrayData = ctx.freshName("unsafeArrayData")
+ val structsOffset = ctx.freshName("structsOffset")
+ val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
+
+ val baseOffset = Platform.BYTE_ARRAY_OFFSET
+ val wordSize = UnsafeRow.WORD_SIZE
+ val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
+ val structSizeAsLong = structSize + "L"
+ val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
+ val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)
+
+ val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});"
+ val valueAssignmentChecked = if (childDataType.valueContainsNull) {
+ s"""
+ |if ($values.isNullAt(z)) {
+ | $unsafeRow.setNullAt(1);
+ |} else {
+ | $valueAssignment
+ |}
+ """.stripMargin
+ } else {
+ valueAssignment
+ }
+
+ val assignmentLoop = (byteArray: String) =>
+ s"""
+ |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
+ |UnsafeRow $unsafeRow = new UnsafeRow(2);
+ |for (int z = 0; z < $numElements; z++) {
+ | long offset = $structsOffset + z * $structSizeAsLong;
+ | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
+ | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize);
+ | $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
+ | $valueAssignmentChecked
+ |}
+ |$arrayData = $unsafeArrayData;
+ """.stripMargin
+
+ ctx.createUnsafeArrayWithFallback(
+ unsafeArrayData,
+ numElements,
+ structSize + wordSize,
+ assignmentLoop,
+ genCodeForAnyElements(ctx, keys, values, arrayData, numElements))
+ }
+
+ private def genCodeForAnyElements(
+ ctx: CodegenContext,
+ keys: String,
+ values: String,
+ arrayData: String,
+ numElements: String): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val rowClass = classOf[GenericInternalRow].getName
+ val data = ctx.freshName("internalRowArray")
+
+ val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
+ val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
+ s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
+ } else {
+ getValue(values)
+ }
+
+ s"""
+ |final Object[] $data = new Object[$numElements];
+ |for (int z = 0; z < $numElements; z++) {
+ | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck});
+ |}
+ |$arrayData = new $genericArrayClass($data);
+ """.stripMargin
+ }
+
+ override def prettyName: String = "map_entries"
+}
+
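
As a usage illustration only (assuming a Spark build that already contains this patch), the new map_entries function can be exercised end to end from SQL; the struct fields are named key and value as declared in dataType above.

import org.apache.spark.sql.SparkSession

object MapEntriesDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("map_entries-demo").getOrCreate()
    // Expected result: an array of <key, value> structs, e.g. [{1, a}, {2, b}]
    spark.sql("SELECT map_entries(map(1, 'a', 2, 'b')) AS entries").show(false)
    spark.stop()
  }
}
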
+/**
+ * Common base class for [[SortArray]] and [[ArraySort]].
+ */
+trait ArraySortLike extends ExpectsInputTypes {
+ protected def arrayExpression: Expression
+
+ protected def nullOrder: NullOrder
+
@transient
private lazy val lt: Comparator[Any] = {
- val ordering = base.dataType match {
+ val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
@@ -167,9 +495,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
- -1
+ nullOrder
} else if (o2 == null) {
- 1
+ -nullOrder
} else {
ordering.compare(o1, o2)
}
@@ -179,7 +507,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
@transient
private lazy val gt: Comparator[Any] = {
- val ordering = base.dataType match {
+ val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
@@ -190,9 +518,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
- 1
+ -nullOrder
} else if (o2 == null) {
- -1
+ nullOrder
} else {
-ordering.compare(o1, o2)
}
@@ -200,18 +528,287 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
}
}
- override def nullSafeEval(array: Any, ascending: Any): Any = {
- val elementType = base.dataType.asInstanceOf[ArrayType].elementType
+ def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType
+ def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull
+
+ def sortEval(array: Any, ascending: Boolean): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
if (elementType != NullType) {
- java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt)
+ java.util.Arrays.sort(data, if (ascending) lt else gt)
}
new GenericArrayData(data.asInstanceOf[Array[Any]])
}
+ def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = {
+ val arrayData = classOf[ArrayData].getName
+ val genericArrayData = classOf[GenericArrayData].getName
+ val unsafeArrayData = classOf[UnsafeArrayData].getName
+ val array = ctx.freshName("array")
+ val c = ctx.freshName("c")
+ if (elementType == NullType) {
+ s"${ev.value} = $base.copy();"
+ } else {
+ val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType)
+ val sortOrder = ctx.freshName("sortOrder")
+ val o1 = ctx.freshName("o1")
+ val o2 = ctx.freshName("o2")
+ val jt = CodeGenerator.javaType(elementType)
+ val comp = if (CodeGenerator.isPrimitiveType(elementType)) {
+ val bt = CodeGenerator.boxedType(elementType)
+ val v1 = ctx.freshName("v1")
+ val v2 = ctx.freshName("v2")
+ s"""
+ |$jt $v1 = (($bt) $o1).${jt}Value();
+ |$jt $v2 = (($bt) $o2).${jt}Value();
+ |int $c = ${ctx.genComp(elementType, v1, v2)};
+ """.stripMargin
+ } else {
+ s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};"
+ }
+ val nonNullPrimitiveAscendingSort =
+ if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) {
+ val javaType = CodeGenerator.javaType(elementType)
+ val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType)
+ s"""
+ |if ($order) {
+ | $javaType[] $array = $base.to${primitiveTypeName}Array();
+ | java.util.Arrays.sort($array);
+ | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array);
+ |} else
+ """.stripMargin
+ } else {
+ ""
+ }
+ s"""
+ |$nonNullPrimitiveAscendingSort
+ |{
+ | Object[] $array = $base.toObjectArray($elementTypeTerm);
+ | final int $sortOrder = $order ? 1 : -1;
+ | java.util.Arrays.sort($array, new java.util.Comparator() {
+ | @Override public int compare(Object $o1, Object $o2) {
+ | if ($o1 == null && $o2 == null) {
+ | return 0;
+ | } else if ($o1 == null) {
+ | return $sortOrder * $nullOrder;
+ | } else if ($o2 == null) {
+ | return -$sortOrder * $nullOrder;
+ | }
+ | $comp
+ | return $sortOrder * $c;
+ | }
+ | });
+ | ${ev.value} = new $genericArrayData($array);
+ |}
+ """.stripMargin
+ }
+ }
+
+}
+
+object ArraySortLike {
+ type NullOrder = Int
+ // Least: place null element at the first of the array for ascending order
+ // Greatest: place null element at the end of the array for ascending order
+ object NullOrder {
+ val Least: NullOrder = -1
+ val Greatest: NullOrder = 1
+ }
+}
+
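
A plain-Scala sketch, illustrative only, of how the NullOrder constant composes with the requested direction in the comparators above: Least pushes nulls to the front of an ascending sort, Greatest pushes them to the back, and a descending sort flips both.

object NullOrderSketch {
  // nullOrder follows the convention above: -1 = Least, 1 = Greatest.
  def comparator(ascending: Boolean, nullOrder: Int): java.util.Comparator[Integer] =
    new java.util.Comparator[Integer] {
      private val sign = if (ascending) 1 else -1
      override def compare(a: Integer, b: Integer): Int =
        if (a == null && b == null) 0
        else if (a == null) sign * nullOrder
        else if (b == null) -sign * nullOrder
        else sign * a.compareTo(b)
    }

  def main(args: Array[String]): Unit = {
    val nullsFirst = Array[Integer](3, null, 1, 2)
    java.util.Arrays.sort(nullsFirst, comparator(ascending = true, nullOrder = -1))
    println(nullsFirst.mkString(", "))   // null, 1, 2, 3

    val nullsLast = Array[Integer](3, null, 1, 2)
    java.util.Arrays.sort(nullsLast, comparator(ascending = true, nullOrder = 1))
    println(nullsLast.mkString(", "))    // 1, 2, 3, null
  }
}
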
+/**
+ * Sorts the input array in ascending / descending order according to the natural ordering of
+ * the array elements and returns it.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order
+ according to the natural ordering of the array elements. Null elements will be placed
+ at the beginning of the returned array in ascending order or at the end of the returned
+ array in descending order.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true);
+ [null,"a","b","c","d"]
+ """)
+// scalastyle:on line.size.limit
+case class SortArray(base: Expression, ascendingOrder: Expression)
+ extends BinaryExpression with ArraySortLike {
+
+ def this(e: Expression) = this(e, Literal(true))
+
+ override def left: Expression = base
+ override def right: Expression = ascendingOrder
+ override def dataType: DataType = base.dataType
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
+
+ override def arrayExpression: Expression = base
+ override def nullOrder: NullOrder = NullOrder.Least
+
+ override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
+ case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
+ ascendingOrder match {
+ case Literal(_: Boolean, BooleanType) =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ =>
+ TypeCheckResult.TypeCheckFailure(
+ "Sort order in second argument requires a boolean literal.")
+ }
+ case ArrayType(dt, _) =>
+ val dtSimple = dt.simpleString
+ TypeCheckResult.TypeCheckFailure(
+ s"$prettyName does not support sorting array of type $dtSimple which is not orderable")
+ case _ =>
+ TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
+ }
+
+ override def nullSafeEval(array: Any, ascending: Any): Any = {
+ sortEval(array, ascending.asInstanceOf[Boolean])
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order))
+ }
+
override def prettyName: String = "sort_array"
}
+
+/**
+ * Sorts the input array in ascending order according to the natural ordering of
+ * the array elements and returns it.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must
+ be orderable. Null elements will be placed at the end of the returned array.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
+ ["a","b","c","d",null]
+ """,
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike {
+
+ override def dataType: DataType = child.dataType
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+ override def arrayExpression: Expression = child
+ override def nullOrder: NullOrder = NullOrder.Greatest
+
+ override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+ case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
+ TypeCheckResult.TypeCheckSuccess
+ case ArrayType(dt, _) =>
+ val dtSimple = dt.simpleString
+ TypeCheckResult.TypeCheckFailure(
+ s"$prettyName does not support sorting array of type $dtSimple which is not orderable")
+ case _ =>
+ TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
+ }
+
+ override def nullSafeEval(array: Any): Any = {
+ sortEval(array, true)
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true"))
+ }
+
+ override def prettyName: String = "array_sort"
+}
+
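
An SQL-level illustration (assuming a build with this patch) of the differing null placement documented above: sort_array keeps its behaviour of placing nulls first in ascending order, while the new array_sort places them last.

import org.apache.spark.sql.SparkSession

object ArraySortDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("array-sort-demo").getOrCreate()
    // sort_array with ascending order: nulls placed at the beginning.
    spark.sql("SELECT sort_array(array('b', 'd', null, 'c', 'a'), true)").show(false)
    // array_sort: nulls placed at the end.
    spark.sql("SELECT array_sort(array('b', 'd', null, 'c', 'a'))").show(false)
    spark.stop()
  }
}
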
+/**
+ * Returns a reversed string or an array with reverse order of elements.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('Spark SQL');
+ LQS krapS
+ > SELECT _FUNC_(array(2, 1, 4, 3));
+ [3, 4, 1, 2]
+ """,
+ since = "1.5.0",
+ note = "Reverse logic for arrays is available since 2.4.0."
+)
+case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+ // Input types are utilized by type coercion in ImplicitTypeCasts.
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))
+
+ override def dataType: DataType = child.dataType
+
+ lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+
+ override def nullSafeEval(input: Any): Any = input match {
+ case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
+ case s: UTF8String => s.reverse()
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => dataType match {
+ case _: StringType => stringCodeGen(ev, c)
+ case _: ArrayType => arrayCodeGen(ctx, ev, c)
+ })
+ }
+
+ private def stringCodeGen(ev: ExprCode, childName: String): String = {
+ s"${ev.value} = ($childName).reverse();"
+ }
+
+ private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
+ val length = ctx.freshName("length")
+ val javaElementType = CodeGenerator.javaType(elementType)
+ val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)
+
+ val initialization = if (isPrimitiveType) {
+ s"$childName.copy()"
+ } else {
+ s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
+ }
+
+ val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length
+
+ val swapAssigments = if (isPrimitiveType) {
+ val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
+ val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
+ s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
+ |boolean isNullAtL = ${ev.value}.isNullAt(l);
+ |if(!isNullAtK) {
+ | $javaElementType el = ${getCall("k")};
+ | if(!isNullAtL) {
+ | ${ev.value}.$setFunc(k, ${getCall("l")});
+ | } else {
+ | ${ev.value}.setNullAt(k);
+ | }
+ | ${ev.value}.$setFunc(l, el);
+ |} else if (!isNullAtL) {
+ | ${ev.value}.$setFunc(k, ${getCall("l")});
+ | ${ev.value}.setNullAt(l);
+ |}""".stripMargin
+ } else {
+ s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
+ }
+
+ s"""
+ |final int $length = $childName.numElements();
+ |${ev.value} = $initialization;
+ |for(int k = 0; k < $numberOfIterations; k++) {
+ | int l = $length - k - 1;
+ | $swapAssigments
+ |}
+ """.stripMargin
+ }
+
+ override def prettyName: String = "reverse"
+}
+
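
A plain-Scala sketch, not the generated Java itself, of the two-pointer swap used by the primitive branch of arrayCodeGen above; only length / 2 iterations are needed because each pass fixes both ends of the array.

object ReverseSketch {
  // Mirrors the primitive-array branch: swap positions k and length - k - 1.
  def reverseInPlace(values: Array[Int]): Array[Int] = {
    val length = values.length
    var k = 0
    while (k < length / 2) {
      val l = length - k - 1
      val tmp = values(k)
      values(k) = values(l)
      values(l) = tmp
      k += 1
    }
    values
  }

  def main(args: Array[String]): Unit = {
    println(reverseInPlace(Array(2, 1, 4, 3)).mkString(", "))  // 3, 4, 1, 2
  }
}
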
/**
* Checks if the array (left) has the element (right)
*/
@@ -227,6 +824,9 @@ case class ArrayContains(left: Expression, right: Expression)
override def dataType: DataType = BooleanType
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(right.dataType)
+
override def inputTypes: Seq[AbstractDataType] = right.dataType match {
case NullType => Seq.empty
case _ => left.dataType match {
@@ -243,7 +843,7 @@ case class ArrayContains(left: Expression, right: Expression)
TypeCheckResult.TypeCheckFailure(
"Arguments must be an array followed by a value of same type as the array members")
} else {
- TypeCheckResult.TypeCheckSuccess
+ TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
}
}
@@ -256,7 +856,7 @@ case class ArrayContains(left: Expression, right: Expression)
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v == null) {
hasNull = true
- } else if (v == value) {
+ } else if (ordering.equiv(v, value)) {
return true
}
)
@@ -270,7 +870,7 @@ case class ArrayContains(left: Expression, right: Expression)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (arr, value) => {
val i = ctx.freshName("i")
- val getValue = ctx.getValue(arr, right.dataType, i)
+ val getValue = CodeGenerator.getValue(arr, right.dataType, i)
s"""
for (int $i = 0; $i < $arr.numElements(); $i ++) {
if ($arr.isNullAt($i)) {
@@ -287,3 +887,1471 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}
+
+/**
+ * Checks if the two arrays contain at least one common element.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5));
+ true
+ """, since = "2.4.0")
+// scalastyle:on line.size.limit
+case class ArraysOverlap(left: Expression, right: Expression)
+ extends BinaryArrayExpressionWithImplicitCast {
+
+ override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess =>
+ TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
+ case failure => failure
+ }
+
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(elementType)
+
+ @transient private lazy val elementTypeSupportEquals = elementType match {
+ case BinaryType => false
+ case _: AtomicType => true
+ case _ => false
+ }
+
+ @transient private lazy val doEvaluation = if (elementTypeSupportEquals) {
+ fastEval _
+ } else {
+ bruteForceEval _
+ }
+
+ override def dataType: DataType = BooleanType
+
+ override def nullable: Boolean = {
+ left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
+ right.dataType.asInstanceOf[ArrayType].containsNull
+ }
+
+ override def nullSafeEval(a1: Any, a2: Any): Any = {
+ doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData])
+ }
+
+ /**
+ * A fast implementation which puts all the elements from the smaller array in a set
+ * and then performs a lookup on it for each element of the bigger one.
+ * This eval mode works only for data types that properly implement the equals method.
+ */
+ private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = {
+ var hasNull = false
+ val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) {
+ (arr1, arr2)
+ } else {
+ (arr2, arr1)
+ }
+ if (smaller.numElements() > 0) {
+ val smallestSet = new mutable.HashSet[Any]
+ smaller.foreach(elementType, (_, v) =>
+ if (v == null) {
+ hasNull = true
+ } else {
+ smallestSet += v
+ })
+ bigger.foreach(elementType, (_, v1) =>
+ if (v1 == null) {
+ hasNull = true
+ } else if (smallestSet.contains(v1)) {
+ return true
+ }
+ )
+ }
+ if (hasNull) {
+ null
+ } else {
+ false
+ }
+ }
+
+ /**
+ * A slower evaluation which performs a nested loop and supports all the data types.
+ */
+ private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = {
+ var hasNull = false
+ if (arr1.numElements() > 0 && arr2.numElements() > 0) {
+ arr1.foreach(elementType, (_, v1) =>
+ if (v1 == null) {
+ hasNull = true
+ } else {
+ arr2.foreach(elementType, (_, v2) =>
+ if (v2 == null) {
+ hasNull = true
+ } else if (ordering.equiv(v1, v2)) {
+ return true
+ }
+ )
+ })
+ }
+ if (hasNull) {
+ null
+ } else {
+ false
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (a1, a2) => {
+ val smaller = ctx.freshName("smallerArray")
+ val bigger = ctx.freshName("biggerArray")
+ val comparisonCode = if (elementTypeSupportEquals) {
+ fastCodegen(ctx, ev, smaller, bigger)
+ } else {
+ bruteForceCodegen(ctx, ev, smaller, bigger)
+ }
+ s"""
+ |ArrayData $smaller;
+ |ArrayData $bigger;
+ |if ($a1.numElements() > $a2.numElements()) {
+ | $bigger = $a1;
+ | $smaller = $a2;
+ |} else {
+ | $smaller = $a1;
+ | $bigger = $a2;
+ |}
+ |if ($smaller.numElements() > 0) {
+ | $comparisonCode
+ |}
+ """.stripMargin
+ })
+ }
+
+ /**
+ * Code generation for a fast implementation which puts all the elements from the smaller array
+ * in a set and then performs a lookup on it for each element of the bigger one.
+ * It works only for data types that properly implement the equals method.
+ */
+ private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = {
+ val i = ctx.freshName("i")
+ val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i)
+ val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
+ val javaElementClass = CodeGenerator.boxedType(elementType)
+ val javaSet = classOf[java.util.HashSet[_]].getName
+ val set = ctx.freshName("set")
+ val addToSetFromSmallerCode = nullSafeElementCodegen(
+ smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;")
+ val elementIsInSetCode = nullSafeElementCodegen(
+ bigger,
+ i,
+ s"""
+ |if ($set.contains($getFromBigger)) {
+ | ${ev.isNull} = false;
+ | ${ev.value} = true;
+ | break;
+ |}
+ """.stripMargin,
+ s"${ev.isNull} = true;")
+ s"""
+ |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>();
+ |for (int $i = 0; $i < $smaller.numElements(); $i ++) {
+ | $addToSetFromSmallerCode
+ |}
+ |for (int $i = 0; $i < $bigger.numElements(); $i ++) {
+ | $elementIsInSetCode
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Code generation for a slower evaluation which performs a nested loop and supports all the data types.
+ */
+ private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = {
+ val i = ctx.freshName("i")
+ val j = ctx.freshName("j")
+ val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j)
+ val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
+ val compareValues = nullSafeElementCodegen(
+ smaller,
+ j,
+ s"""
+ |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) {
+ | ${ev.isNull} = false;
+ | ${ev.value} = true;
+ |}
+ """.stripMargin,
+ s"${ev.isNull} = true;")
+ val isInSmaller = nullSafeElementCodegen(
+ bigger,
+ i,
+ s"""
+ |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) {
+ | $compareValues
+ |}
+ """.stripMargin,
+ s"${ev.isNull} = true;")
+ s"""
+ |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) {
+ | $isInSmaller
+ |}
+ """.stripMargin
+ }
+
+ def nullSafeElementCodegen(
+ arrayVar: String,
+ index: String,
+ code: String,
+ isNullCode: String): String = {
+ if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) {
+ s"""
+ |if ($arrayVar.isNullAt($index)) {
+ | $isNullCode
+ |} else {
+ | $code
+ |}
+ """.stripMargin
+ } else {
+ code
+ }
+ }
+
+ override def prettyName: String = "arrays_overlap"
+}
+
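
A plain-Scala sketch (illustrative, without Spark's ArrayData types) of the fast path in fastEval above: the smaller array is loaded into a hash set, the bigger one probes it, and the result degrades to null (modelled as None here) when no match is found but a null element was seen.

import scala.collection.mutable

object ArraysOverlapSketch {
  // Some(true) on a common non-null element, None when the answer is unknown because a
  // null was encountered, Some(false) otherwise -- mirroring the three-valued fastEval.
  def overlap(arr1: Seq[Any], arr2: Seq[Any]): Option[Boolean] = {
    val (bigger, smaller) = if (arr1.size > arr2.size) (arr1, arr2) else (arr2, arr1)
    var hasNull = false
    if (smaller.nonEmpty) {
      val seen = new mutable.HashSet[Any]
      smaller.foreach(v => if (v == null) hasNull = true else seen += v)
      bigger.foreach { v =>
        if (v == null) hasNull = true
        else if (seen.contains(v)) return Some(true)
      }
    }
    if (hasNull) None else Some(false)
  }

  def main(args: Array[String]): Unit = {
    println(overlap(Seq(1, 2, 3), Seq(3, 4, 5)))   // Some(true)
    println(overlap(Seq(1, null), Seq(4, 5)))       // None: unknown because of the null
    println(overlap(Seq(1, 2), Seq(4, 5)))          // Some(false)
  }
}
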
+/**
+ * Slices an array according to the requested start index and length
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2);
+ [2,3]
+ > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2);
+ [3,4]
+ """, since = "2.4.0")
+// scalastyle:on line.size.limit
+case class Slice(x: Expression, start: Expression, length: Expression)
+ extends TernaryExpression with ImplicitCastInputTypes {
+
+ override def dataType: DataType = x.dataType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)
+
+ override def children: Seq[Expression] = Seq(x, start, length)
+
+ lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
+
+ override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
+ val startInt = startVal.asInstanceOf[Int]
+ val lengthInt = lengthVal.asInstanceOf[Int]
+ val arr = xVal.asInstanceOf[ArrayData]
+ val startIndex = if (startInt == 0) {
+ throw new RuntimeException(
+ s"Unexpected value for start in function $prettyName: SQL array indices start at 1.")
+ } else if (startInt < 0) {
+ startInt + arr.numElements()
+ } else {
+ startInt - 1
+ }
+ if (lengthInt < 0) {
+ throw new RuntimeException(s"Unexpected value for length in function $prettyName: " +
+ "length must be greater than or equal to 0.")
+ }
+ // startIndex can be negative if start is negative and its absolute value is greater than the
+ // number of elements in the array
+ if (startIndex < 0 || startIndex >= arr.numElements()) {
+ return new GenericArrayData(Array.empty[AnyRef])
+ }
+ val data = arr.toSeq[AnyRef](elementType)
+ new GenericArrayData(data.slice(startIndex, startIndex + lengthInt))
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (x, start, length) => {
+ val startIdx = ctx.freshName("startIdx")
+ val resLength = ctx.freshName("resLength")
+ val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
+ s"""
+ |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue;
+ |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue;
+ |if ($start == 0) {
+ | throw new RuntimeException("Unexpected value for start in function $prettyName: "
+ | + "SQL array indices start at 1.");
+ |} else if ($start < 0) {
+ | $startIdx = $start + $x.numElements();
+ |} else {
+ | // arrays in SQL are 1-based instead of 0-based
+ | $startIdx = $start - 1;
+ |}
+ |if ($length < 0) {
+ | throw new RuntimeException("Unexpected value for length in function $prettyName: "
+ | + "length must be greater than or equal to 0.");
+ |} else if ($length > $x.numElements() - $startIdx) {
+ | $resLength = $x.numElements() - $startIdx;
+ |} else {
+ | $resLength = $length;
+ |}
+ |${genCodeForResult(ctx, ev, x, startIdx, resLength)}
+ """.stripMargin
+ })
+ }
+
+ def genCodeForResult(
+ ctx: CodegenContext,
+ ev: ExprCode,
+ inputArray: String,
+ startIdx: String,
+ resLength: String): String = {
+ val values = ctx.freshName("values")
+ val i = ctx.freshName("i")
+ val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx")
+ if (!CodeGenerator.isPrimitiveType(elementType)) {
+ val arrayClass = classOf[GenericArrayData].getName
+ s"""
+ |Object[] $values;
+ |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
+ | $values = new Object[0];
+ |} else {
+ | $values = new Object[$resLength];
+ | for (int $i = 0; $i < $resLength; $i ++) {
+ | $values[$i] = $getValue;
+ | }
+ |}
+ |${ev.value} = new $arrayClass($values);
+ """.stripMargin
+ } else {
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+ s"""
+ |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
+ | $resLength = 0;
+ |}
+ |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")}
+ |for (int $i = 0; $i < $resLength; $i ++) {
+ | if ($inputArray.isNullAt($i + $startIdx)) {
+ | $values.setNullAt($i);
+ | } else {
+ | $values.set$primitiveValueTypeName($i, $getValue);
+ | }
+ |}
+ |${ev.value} = $values;
+ """.stripMargin
+ }
+ }
+}
+
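
A plain-Scala sketch of the index arithmetic in nullSafeEval above: slices are 1-based, a negative start counts from the end, start = 0 is rejected, and a start that falls outside the array yields an empty result.

object SliceSketch {
  def slice[T](arr: Seq[T], start: Int, length: Int): Seq[T] = {
    require(start != 0, "SQL array indices start at 1")
    require(length >= 0, "length must be greater than or equal to 0")
    val startIndex = if (start < 0) start + arr.length else start - 1
    if (startIndex < 0 || startIndex >= arr.length) Seq.empty
    else arr.slice(startIndex, startIndex + length)
  }

  def main(args: Array[String]): Unit = {
    println(slice(Seq(1, 2, 3, 4), 2, 2))   // List(2, 3)
    println(slice(Seq(1, 2, 3, 4), -2, 2))  // List(3, 4)
    println(slice(Seq(1, 2, 3, 4), -5, 2))  // List(): start falls before the array
  }
}
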
+/**
+ * Creates a String containing all the elements of the input array separated by the delimiter.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array
+ using the delimiter and an optional string to replace nulls. If no value is set for
+ nullReplacement, any null value is filtered.""",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array('hello', 'world'), ' ');
+ hello world
+ > SELECT _FUNC_(array('hello', null ,'world'), ' ');
+ hello world
+ > SELECT _FUNC_(array('hello', null ,'world'), ' ', ',');
+ hello , world
+ """, since = "2.4.0")
+case class ArrayJoin(
+ array: Expression,
+ delimiter: Expression,
+ nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes {
+
+ def this(array: Expression, delimiter: Expression) = this(array, delimiter, None)
+
+ def this(array: Expression, delimiter: Expression, nullReplacement: Expression) =
+ this(array, delimiter, Some(nullReplacement))
+
+ override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
+ Seq(ArrayType(StringType), StringType, StringType)
+ } else {
+ Seq(ArrayType(StringType), StringType)
+ }
+
+ override def children: Seq[Expression] = if (nullReplacement.isDefined) {
+ Seq(array, delimiter, nullReplacement.get)
+ } else {
+ Seq(array, delimiter)
+ }
+
+ override def nullable: Boolean = children.exists(_.nullable)
+
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ override def eval(input: InternalRow): Any = {
+ val arrayEval = array.eval(input)
+ if (arrayEval == null) return null
+ val delimiterEval = delimiter.eval(input)
+ if (delimiterEval == null) return null
+ val nullReplacementEval = nullReplacement.map(_.eval(input))
+ if (nullReplacementEval.contains(null)) return null
+
+ val buffer = new UTF8StringBuilder()
+ var firstItem = true
+ val nullHandling = nullReplacementEval match {
+ case Some(rep) => (prependDelimiter: Boolean) => {
+ if (!prependDelimiter) {
+ buffer.append(delimiterEval.asInstanceOf[UTF8String])
+ }
+ buffer.append(rep.asInstanceOf[UTF8String])
+ true
+ }
+ case None => (_: Boolean) => false
+ }
+ arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => {
+ if (item == null) {
+ if (nullHandling(firstItem)) {
+ firstItem = false
+ }
+ } else {
+ if (!firstItem) {
+ buffer.append(delimiterEval.asInstanceOf[UTF8String])
+ }
+ buffer.append(item.asInstanceOf[UTF8String])
+ firstItem = false
+ }
+ })
+ buffer.build()
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val code = nullReplacement match {
+ case Some(replacement) =>
+ val replacementGen = replacement.genCode(ctx)
+ val nullHandling = (buffer: String, delimiter: String, firstItem: String) => {
+ s"""
+ |if (!$firstItem) {
+ | $buffer.append($delimiter);
+ |}
+ |$buffer.append(${replacementGen.value});
+ |$firstItem = false;
+ """.stripMargin
+ }
+ val execCode = if (replacement.nullable) {
+ ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) {
+ genCodeForArrayAndDelimiter(ctx, ev, nullHandling)
+ }
+ } else {
+ genCodeForArrayAndDelimiter(ctx, ev, nullHandling)
+ }
+ s"""
+ |${replacementGen.code}
+ |$execCode
+ """.stripMargin
+ case None => genCodeForArrayAndDelimiter(ctx, ev,
+ (_: String, _: String, _: String) => "// nulls are ignored")
+ }
+ if (nullable) {
+ ev.copy(
+ code"""
+ |boolean ${ev.isNull} = true;
+ |UTF8String ${ev.value} = null;
+ |$code
+ """.stripMargin)
+ } else {
+ ev.copy(
+ code"""
+ |UTF8String ${ev.value} = null;
+ |$code
+ """.stripMargin, FalseLiteral)
+ }
+ }
+
+ private def genCodeForArrayAndDelimiter(
+ ctx: CodegenContext,
+ ev: ExprCode,
+ nullEval: (String, String, String) => String): String = {
+ val arrayGen = array.genCode(ctx)
+ val delimiterGen = delimiter.genCode(ctx)
+ val buffer = ctx.freshName("buffer")
+ val bufferClass = classOf[UTF8StringBuilder].getName
+ val i = ctx.freshName("i")
+ val firstItem = ctx.freshName("firstItem")
+ val resultCode =
+ s"""
+ |$bufferClass $buffer = new $bufferClass();
+ |boolean $firstItem = true;
+ |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) {
+ | if (${arrayGen.value}.isNullAt($i)) {
+ | ${nullEval(buffer, delimiterGen.value, firstItem)}
+ | } else {
+ | if (!$firstItem) {
+ | $buffer.append(${delimiterGen.value});
+ | }
+ | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)});
+ | $firstItem = false;
+ | }
+ |}
+ |${ev.value} = $buffer.build();""".stripMargin
+
+ if (array.nullable || delimiter.nullable) {
+ arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) {
+ delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) {
+ s"""
+ |${ev.isNull} = false;
+ |$resultCode""".stripMargin
+ }
+ }
+ } else {
+ s"""
+ |${arrayGen.code}
+ |${delimiterGen.code}
+ |$resultCode""".stripMargin
+ }
+ }
+
+ override def dataType: DataType = StringType
+
+}
+
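
An SQL-level illustration (assuming a build with this patch) of the two null behaviours implemented in eval above: null elements are dropped when no replacement is given and substituted otherwise.

import org.apache.spark.sql.SparkSession

object ArrayJoinDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("array_join-demo").getOrCreate()
    // Nulls are skipped entirely: "hello world"
    spark.sql("SELECT array_join(array('hello', null, 'world'), ' ')").show(false)
    // Nulls are replaced by the third argument: "hello , world"
    spark.sql("SELECT array_join(array('hello', null, 'world'), ' ', ',')").show(false)
    spark.stop()
  }
}
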
+/**
+ * Returns the minimum value in the array.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 20, null, 3));
+ 1
+ """, since = "2.4.0")
+case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+ override def nullable: Boolean = true
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val typeCheckResult = super.checkInputDataTypes()
+ if (typeCheckResult.isSuccess) {
+ TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
+ } else {
+ typeCheckResult
+ }
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val childGen = child.genCode(ctx)
+ val javaType = CodeGenerator.javaType(dataType)
+ val i = ctx.freshName("i")
+ val item = ExprCode(EmptyBlock,
+ isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
+ value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
+ ev.copy(code =
+ code"""
+ |${childGen.code}
+ |boolean ${ev.isNull} = true;
+ |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ |if (!${childGen.isNull}) {
+ | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
+ | ${ctx.reassignIfSmaller(dataType, ev, item)}
+ | }
+ |}
+ """.stripMargin)
+ }
+
+ override protected def nullSafeEval(input: Any): Any = {
+ var min: Any = null
+ input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
+ if (item != null && (min == null || ordering.lt(item, min))) {
+ min = item
+ }
+ )
+ min
+ }
+
+ override def dataType: DataType = child.dataType match {
+ case ArrayType(dt, _) => dt
+ case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
+ }
+
+ override def prettyName: String = "array_min"
+}
+
+/**
+ * Returns the maximum value in the array.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 20, null, 3));
+ 20
+ """, since = "2.4.0")
+case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+ override def nullable: Boolean = true
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+ private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val typeCheckResult = super.checkInputDataTypes()
+ if (typeCheckResult.isSuccess) {
+ TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
+ } else {
+ typeCheckResult
+ }
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val childGen = child.genCode(ctx)
+ val javaType = CodeGenerator.javaType(dataType)
+ val i = ctx.freshName("i")
+ val item = ExprCode(EmptyBlock,
+ isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
+ value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
+ ev.copy(code =
+ code"""
+ |${childGen.code}
+ |boolean ${ev.isNull} = true;
+ |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ |if (!${childGen.isNull}) {
+ | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
+ | ${ctx.reassignIfGreater(dataType, ev, item)}
+ | }
+ |}
+ """.stripMargin)
+ }
+
+ override protected def nullSafeEval(input: Any): Any = {
+ var max: Any = null
+ input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
+ if (item != null && (max == null || ordering.gt(item, max))) {
+ max = item
+ }
+ )
+ max
+ }
+
+ override def dataType: DataType = child.dataType match {
+ case ArrayType(dt, _) => dt
+ case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
+ }
+
+ override def prettyName: String = "array_max"
+}
+
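
A plain-Scala sketch of the null-skipping scan that array_min and array_max share in nullSafeEval: the running extreme is only updated from non-null elements, and an empty or all-null array yields null.

object ArrayExtremesSketch {
  // Mirrors ArrayMin.nullSafeEval for integers; the max variant only flips the comparison.
  def arrayMin(values: Seq[Integer]): Integer = {
    var min: Integer = null
    values.foreach { item =>
      if (item != null && (min == null || item.compareTo(min) < 0)) min = item
    }
    min
  }

  def main(args: Array[String]): Unit = {
    println(arrayMin(Seq[Integer](1, 20, null, 3)))  // 1
    println(arrayMin(Seq[Integer](null, null)))      // null
  }
}
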
+
+/**
+ * Returns the position of the first occurrence of element in the given array as long.
+ * Returns 0 if the given value could not be found in the array. Returns null if either of
+ * the arguments are null
+ *
+ * NOTE: this is not a zero-based but a 1-based index. The first element in the array has
+ * index 1.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array, element) - Returns the (1-based) index of the first occurrence of element in the array, as a long.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(3, 2, 1), 1);
+ 3
+ """,
+ since = "2.4.0")
+case class ArrayPosition(left: Expression, right: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(right.dataType)
+
+ override def dataType: DataType = LongType
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ val elementType = left.dataType match {
+ case t: ArrayType => t.elementType
+ case _ => AnyDataType
+ }
+ Seq(ArrayType, elementType)
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ super.checkInputDataTypes() match {
+ case f: TypeCheckResult.TypeCheckFailure => f
+ case TypeCheckResult.TypeCheckSuccess =>
+ TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
+ }
+ }
+
+ override def nullSafeEval(arr: Any, value: Any): Any = {
+ arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
+ if (v != null && ordering.equiv(v, value)) {
+ return (i + 1).toLong
+ }
+ )
+ 0L
+ }
+
+ override def prettyName: String = "array_position"
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (arr, value) => {
+ val pos = ctx.freshName("arrayPosition")
+ val i = ctx.freshName("i")
+ val getValue = CodeGenerator.getValue(arr, right.dataType, i)
+ s"""
+ |int $pos = 0;
+ |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+ | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) {
+ | $pos = $i + 1;
+ | break;
+ | }
+ |}
+ |${ev.value} = (long) $pos;
+ """.stripMargin
+ })
+ }
+}
+
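
A plain-Scala sketch of the 1-based lookup in nullSafeEval above: the first non-null element equal to the probe yields its 1-based index, and 0 means the value was not found.

object ArrayPositionSketch {
  // Mirrors ArrayPosition.nullSafeEval: 1-based index of the first match, 0 if absent.
  def position(arr: Seq[Any], value: Any): Long = {
    arr.zipWithIndex.foreach { case (v, i) =>
      if (v != null && v == value) return (i + 1).toLong
    }
    0L
  }

  def main(args: Array[String]): Unit = {
    println(position(Seq(3, 2, 1), 1))  // 3
    println(position(Seq(3, 2, 1), 7))  // 0
  }
}
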
+/**
+ * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
+ accesses elements from the last to the first. Returns NULL if the index exceeds the length
+ of the array.
+
+ _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), 2);
+ 2
+ > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2);
+ "b"
+ """,
+ since = "2.4.0")
+case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
+
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType)
+
+ override def dataType: DataType = left.dataType match {
+ case ArrayType(elementType, _) => elementType
+ case MapType(_, valueType, _) => valueType
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ Seq(TypeCollection(ArrayType, MapType),
+ left.dataType match {
+ case _: ArrayType => IntegerType
+ case _: MapType => left.dataType.asInstanceOf[MapType].keyType
+ case _ => AnyDataType // no match for a wrong 'left' expression type
+ }
+ )
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ super.checkInputDataTypes() match {
+ case f: TypeCheckResult.TypeCheckFailure => f
+ case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
+ TypeUtils.checkForOrderingExpr(
+ left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName")
+ case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ override def nullable: Boolean = true
+
+ override def nullSafeEval(value: Any, ordinal: Any): Any = {
+ left.dataType match {
+ case _: ArrayType =>
+ val array = value.asInstanceOf[ArrayData]
+ val index = ordinal.asInstanceOf[Int]
+ if (array.numElements() < math.abs(index)) {
+ null
+ } else {
+ val idx = if (index == 0) {
+ throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
+ } else if (index > 0) {
+ index - 1
+ } else {
+ array.numElements() + index
+ }
+ if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) {
+ null
+ } else {
+ array.get(idx, dataType)
+ }
+ }
+ case _: MapType =>
+ getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering)
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ left.dataType match {
+ case _: ArrayType =>
+ nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+ val index = ctx.freshName("elementAtIndex")
+ val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
+ s"""
+ |if ($eval1.isNullAt($index)) {
+ | ${ev.isNull} = true;
+ |} else
+ """.stripMargin
+ } else {
+ ""
+ }
+ s"""
+ |int $index = (int) $eval2;
+ |if ($eval1.numElements() < Math.abs($index)) {
+ | ${ev.isNull} = true;
+ |} else {
+ | if ($index == 0) {
+ | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
+ | } else if ($index > 0) {
+ | $index--;
+ | } else {
+ | $index += $eval1.numElements();
+ | }
+ | $nullCheck
+ | {
+ | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
+ | }
+ |}
+ """.stripMargin
+ })
+ case _: MapType =>
+ doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType])
+ }
+ }
+
+ override def prettyName: String = "element_at"
+}
+
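
A plain-Scala sketch of the array branch of nullSafeEval above: index 0 is rejected, positive indexes are 1-based from the front, negative indexes count from the back, and an index beyond either end returns null.

object ElementAtSketch {
  def elementAt[T >: Null](arr: IndexedSeq[T], index: Int): T = {
    if (arr.length < math.abs(index)) {
      null
    } else {
      val idx =
        if (index == 0) throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
        else if (index > 0) index - 1
        else arr.length + index
      arr(idx)
    }
  }

  def main(args: Array[String]): Unit = {
    println(elementAt(Vector("a", "b", "c"), 2))   // b
    println(elementAt(Vector("a", "b", "c"), -1))  // c
    println(elementAt(Vector("a", "b", "c"), 5))   // null
  }
}
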
+/**
+ * Concatenates multiple input columns together into a single column.
+ * The function works with strings, binary and compatible array columns.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('Spark', 'SQL');
+ SparkSQL
+ > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
+ [1,2,3,4,5,6]
+ """)
+case class Concat(children: Seq[Expression]) extends Expression {
+
+ private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+ val allowedTypes = Seq(StringType, BinaryType, ArrayType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.isEmpty) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ val childTypes = children.map(_.dataType)
+ if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) {
+ return TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName should have been ${StringType.simpleString}," +
+ s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " +
+ childTypes.map(_.simpleString).mkString("[", ", ", "]"))
+ }
+ TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
+ }
+ }
+
+ override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
+
+ lazy val javaType: String = CodeGenerator.javaType(dataType)
+
+ override def nullable: Boolean = children.exists(_.nullable)
+
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ override def eval(input: InternalRow): Any = dataType match {
+ case BinaryType =>
+ val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+ ByteArray.concat(inputs: _*)
+ case StringType =>
+ val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+ UTF8String.concat(inputs : _*)
+ case ArrayType(elementType, _) =>
+ val inputs = children.toStream.map(_.eval(input))
+ if (inputs.contains(null)) {
+ null
+ } else {
+ val arrayData = inputs.map(_.asInstanceOf[ArrayData])
+ val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+ if (numberOfElements > MAX_ARRAY_LENGTH) {
+ throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
+ s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+ }
+ val finalData = new Array[AnyRef](numberOfElements.toInt)
+ var position = 0
+ for(ad <- arrayData) {
+ val arr = ad.toObjectArray(elementType)
+ Array.copy(arr, 0, finalData, position, arr.length)
+ position += arr.length
+ }
+ new GenericArrayData(finalData)
+ }
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val evals = children.map(_.genCode(ctx))
+ val args = ctx.freshName("args")
+
+ val inputs = evals.zipWithIndex.map { case (eval, index) =>
+ s"""
+ ${eval.code}
+ if (!${eval.isNull}) {
+ $args[$index] = ${eval.value};
+ }
+ """
+ }
+
+ val (concatenator, initCode) = dataType match {
+ case BinaryType =>
+ (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
+ case StringType =>
+ ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
+ case ArrayType(elementType, _) =>
+ val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
+ genCodeForPrimitiveArrays(ctx, elementType)
+ } else {
+ genCodeForNonPrimitiveArrays(ctx, elementType)
+ }
+ (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
+ }
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = inputs,
+ funcName = "valueConcat",
+ extraArguments = (s"$javaType[]", args) :: Nil)
+ ev.copy(code"""
+ $initCode
+ $codes
+ $javaType ${ev.value} = $concatenator.concat($args);
+ boolean ${ev.isNull} = ${ev.value} == null;
+ """)
+ }
+
+ private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
+ val numElements = ctx.freshName("numElements")
+ val code = s"""
+ |long $numElements = 0L;
+ |for (int z = 0; z < ${children.length}; z++) {
+ | $numElements += args[z].numElements();
+ |}
+ |if ($numElements > $MAX_ARRAY_LENGTH) {
+ | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
+ | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ |}
+ """.stripMargin
+
+ (code, numElements)
+ }
+
+ private def nullArgumentProtection() : String = {
+ if (nullable) {
+ s"""
+ |for (int z = 0; z < ${children.length}; z++) {
+ | if (args[z] == null) return null;
+ |}
+ """.stripMargin
+ } else {
+ ""
+ }
+ }
+
+ private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+ val counter = ctx.freshName("counter")
+ val arrayData = ctx.freshName("arrayData")
+
+ val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
+
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+ s"""
+ |new Object() {
+ | public ArrayData concat($javaType[] args) {
+ | ${nullArgumentProtection()}
+ | $numElemCode
+ | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
+ | int $counter = 0;
+ | for (int y = 0; y < ${children.length}; y++) {
+ | for (int z = 0; z < args[y].numElements(); z++) {
+ | if (args[y].isNullAt(z)) {
+ | $arrayData.setNullAt($counter);
+ | } else {
+ | $arrayData.set$primitiveValueTypeName(
+ | $counter,
+ | ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
+ | );
+ | }
+ | $counter++;
+ | }
+ | }
+ | return $arrayData;
+ | }
+ |}""".stripMargin.stripPrefix("\n")
+ }
+
+ private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val arrayData = ctx.freshName("arrayObjects")
+ val counter = ctx.freshName("counter")
+
+ val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
+
+ s"""
+ |new Object() {
+ | public ArrayData concat($javaType[] args) {
+ | ${nullArgumentProtection()}
+ | $numElemCode
+ | Object[] $arrayData = new Object[(int)$numElemName];
+ | int $counter = 0;
+ | for (int y = 0; y < ${children.length}; y++) {
+ | for (int z = 0; z < args[y].numElements(); z++) {
+ | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
+ | $counter++;
+ | }
+ | }
+ | return new $genericArrayClass($arrayData);
+ | }
+ |}""".stripMargin.stripPrefix("\n")
+ }
+
+ override def toString: String = s"concat(${children.mkString(", ")})"
+
+ override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
+}
+
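
An SQL-level illustration (assuming a build with this patch) of the array support added to concat above; the existing string behaviour is unchanged.

import org.apache.spark.sql.SparkSession

object ConcatArraysDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("concat-arrays-demo").getOrCreate()
    // Strings: SparkSQL
    spark.sql("SELECT concat('Spark', 'SQL')").show(false)
    // Arrays of the same element type: [1, 2, 3, 4, 5, 6]
    spark.sql("SELECT concat(array(1, 2, 3), array(4, 5), array(6))").show(false)
    spark.stop()
  }
}
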
+/**
+ * Transforms an array of arrays into a single array.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(array(1, 2), array(3, 4)));
+ [1,2,3,4]
+ """,
+ since = "2.4.0")
+case class Flatten(child: Expression) extends UnaryExpression {
+
+ private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+ private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
+
+ override def nullable: Boolean = child.nullable || childDataType.containsNull
+
+ override def dataType: DataType = childDataType.elementType
+
+ lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+
+ override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+ case ArrayType(_: ArrayType, _) =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ =>
+ TypeCheckResult.TypeCheckFailure(
+ s"The argument should be an array of arrays, " +
+ s"but '${child.sql}' is of ${child.dataType.simpleString} type."
+ )
+ }
+
+ override def nullSafeEval(child: Any): Any = {
+ val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType)
+
+ if (elements.contains(null)) {
+ null
+ } else {
+ val arrayData = elements.map(_.asInstanceOf[ArrayData])
+ val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements())
+ if (numberOfElements > MAX_ARRAY_LENGTH) {
+ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
+ s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+ }
+ val flattenedData = new Array[AnyRef](numberOfElements.toInt)
+ var position = 0
+ for (ad <- arrayData) {
+ val arr = ad.toObjectArray(elementType)
+ Array.copy(arr, 0, flattenedData, position, arr.length)
+ position += arr.length
+ }
+ new GenericArrayData(flattenedData)
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => {
+ val code = if (CodeGenerator.isPrimitiveType(elementType)) {
+ genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value)
+ } else {
+ genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
+ }
+ if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code
+ })
+ }
+
+ private def nullElementsProtection(
+ ev: ExprCode,
+ childVariableName: String,
+ coreLogic: String): String = {
+ s"""
+ |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) {
+ | ${ev.isNull} |= $childVariableName.isNullAt(z);
+ |}
+ |if (!${ev.isNull}) {
+ | $coreLogic
+ |}
+ """.stripMargin
+ }
+
+ private def genCodeForNumberOfElements(
+ ctx: CodegenContext,
+ childVariableName: String) : (String, String) = {
+ val variableName = ctx.freshName("numElements")
+ val code = s"""
+ |long $variableName = 0;
+ |for (int z = 0; z < $childVariableName.numElements(); z++) {
+ | $variableName += $childVariableName.getArray(z).numElements();
+ |}
+ |if ($variableName > $MAX_ARRAY_LENGTH) {
+ | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
+ | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ |}
+ """.stripMargin
+ (code, variableName)
+ }
+
+ private def genCodeForFlattenOfPrimitiveElements(
+ ctx: CodegenContext,
+ childVariableName: String,
+ arrayDataName: String): String = {
+ val counter = ctx.freshName("counter")
+ val tempArrayDataName = ctx.freshName("tempArrayData")
+
+ val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)
+
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+ s"""
+ |$numElemCode
+ |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")}
+ |int $counter = 0;
+ |for (int k = 0; k < $childVariableName.numElements(); k++) {
+ | ArrayData arr = $childVariableName.getArray(k);
+ | for (int l = 0; l < arr.numElements(); l++) {
+ | if (arr.isNullAt(l)) {
+ | $tempArrayDataName.setNullAt($counter);
+ | } else {
+ | $tempArrayDataName.set$primitiveValueTypeName(
+ | $counter,
+ | ${CodeGenerator.getValue("arr", elementType, "l")}
+ | );
+ | }
+ | $counter++;
+ | }
+ |}
+ |$arrayDataName = $tempArrayDataName;
+ """.stripMargin
+ }
+
+ private def genCodeForFlattenOfNonPrimitiveElements(
+ ctx: CodegenContext,
+ childVariableName: String,
+ arrayDataName: String): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val arrayName = ctx.freshName("arrayObject")
+ val counter = ctx.freshName("counter")
+ val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)
+
+ s"""
+ |$numElemCode
+ |Object[] $arrayName = new Object[(int)$numElemName];
+ |int $counter = 0;
+ |for (int k = 0; k < $childVariableName.numElements(); k++) {
+ | ArrayData arr = $childVariableName.getArray(k);
+ | for (int l = 0; l < arr.numElements(); l++) {
+ | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")};
+ | $counter++;
+ | }
+ |}
+ |$arrayDataName = new $genericArrayClass($arrayName);
+ """.stripMargin
+ }
+
+ override def prettyName: String = "flatten"
+}
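// Illustrative sketch, not part of the patch: the flatten semantics implemented above, expressed
// with plain Scala collections. As in nullSafeEval, a null outer array or a null inner array makes
// the whole result null, while nulls inside the inner arrays are preserved. MaxArrayLength below is
// an assumption standing in for ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.
object FlattenSketch {
  val MaxArrayLength: Long = Int.MaxValue - 15L

  def flatten(arrays: Seq[Seq[Any]]): Option[Seq[Any]] = {
    if (arrays == null || arrays.contains(null)) {
      None // SQL NULL
    } else {
      val total = arrays.map(_.size.toLong).sum
      require(total <= MaxArrayLength,
        s"Cannot flatten $total elements: exceeds the array size limit $MaxArrayLength")
      Some(arrays.flatten)
    }
  }
}
// FlattenSketch.flatten(Seq(Seq(1, 2), Seq(3, null)))  => Some(List(1, 2, 3, null))
// FlattenSketch.flatten(Seq(Seq(1, 2), null))          => None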
+
+/**
+ * Returns an array containing the given value (left) repeated the given number of times (right).
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(element, count) - Returns the array containing element count times.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('123', 2);
+ ['123', '123']
+ """,
+ since = "2.4.0")
+case class ArrayRepeat(left: Expression, right: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+ override def dataType: ArrayType = ArrayType(left.dataType, left.nullable)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
+
+ override def nullable: Boolean = right.nullable
+
+ override def eval(input: InternalRow): Any = {
+ val count = right.eval(input)
+ if (count == null) {
+ null
+ } else {
+ if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) {
+ throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
+ s"due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ }
+ val element = left.eval(input)
+ new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element))
+ }
+ }
+
+ override def prettyName: String = "array_repeat"
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val leftGen = left.genCode(ctx)
+ val rightGen = right.genCode(ctx)
+ val element = leftGen.value
+ val count = rightGen.value
+ val et = dataType.elementType
+
+ val coreLogic = if (CodeGenerator.isPrimitiveType(et)) {
+ genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value)
+ } else {
+ genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value)
+ }
+ val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic)
+
+ ev.copy(code =
+ code"""
+ |boolean ${ev.isNull} = false;
+ |${leftGen.code}
+ |${rightGen.code}
+ |${CodeGenerator.javaType(dataType)} ${ev.value} =
+ | ${CodeGenerator.defaultValue(dataType)};
+ |$resultCode
+ """.stripMargin)
+ }
+
+ private def nullElementsProtection(
+ ev: ExprCode,
+ rightIsNull: String,
+ coreLogic: String): String = {
+ if (nullable) {
+ s"""
+ |if ($rightIsNull) {
+ | ${ev.isNull} = true;
+ |} else {
+ | ${coreLogic}
+ |}
+ """.stripMargin
+ } else {
+ coreLogic
+ }
+ }
+
+ private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = {
+ val numElements = ctx.freshName("numElements")
+ val numElementsCode =
+ s"""
+ |int $numElements = 0;
+ |if ($count > 0) {
+ | $numElements = $count;
+ |}
+ |if ($numElements > $MAX_ARRAY_LENGTH) {
+ | throw new RuntimeException("Unsuccessful try to create array with " + $numElements +
+ | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ |}
+ """.stripMargin
+
+ (numElements, numElementsCode)
+ }
+
+ private def genCodeForPrimitiveElement(
+ ctx: CodegenContext,
+ elementType: DataType,
+ element: String,
+ count: String,
+ leftIsNull: String,
+ arrayDataName: String): String = {
+ val tempArrayDataName = ctx.freshName("tempArrayData")
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+ val errorMessage = s" $prettyName failed."
+ val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
+
+ s"""
+ |$numElemCode
+ |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)}
+ |if (!$leftIsNull) {
+ | for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
+ | $tempArrayDataName.set$primitiveValueTypeName(k, $element);
+ | }
+ |} else {
+ | for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
+ | $tempArrayDataName.setNullAt(k);
+ | }
+ |}
+ |$arrayDataName = $tempArrayDataName;
+ """.stripMargin
+ }
+
+ private def genCodeForNonPrimitiveElement(
+ ctx: CodegenContext,
+ element: String,
+ count: String,
+ leftIsNull: String,
+ arrayDataName: String): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val arrayName = ctx.freshName("arrayObject")
+ val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
+
+ s"""
+ |$numElemCode
+ |Object[] $arrayName = new Object[(int)$numElemName];
+ |if (!$leftIsNull) {
+ | for (int k = 0; k < $numElemName; k++) {
+ | $arrayName[k] = $element;
+ | }
+ |}
+ |$arrayDataName = new $genericArrayClass($arrayName);
+ """.stripMargin
+ }
+
+}
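// Illustrative sketch, not part of the patch: array_repeat semantics as implemented by the
// eval/codegen paths above. A null count yields null, a null element yields an array of nulls,
// and a non-positive count yields an empty array.
def arrayRepeatSketch(element: Any, count: Integer): Option[Seq[Any]] = {
  if (count == null) {
    None // SQL NULL
  } else {
    Some(Seq.fill(math.max(count.intValue, 0))(element))
  }
}
// arrayRepeatSketch("123", 2)  => Some(List("123", "123"))
// arrayRepeatSketch(null, 3)   => Some(List(null, null, null))
// arrayRepeatSketch("x", null) => None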
+
+/**
+ * Removes all elements that equal the given element from the array.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3, null, 3), 3);
+ [1,2,null]
+ """, since = "2.4.0")
+case class ArrayRemove(left: Expression, right: Expression)
+ extends BinaryExpression with ImplicitCastInputTypes {
+
+ override def dataType: DataType = left.dataType
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ val elementType = left.dataType match {
+ case t: ArrayType => t.elementType
+ case _ => AnyDataType
+ }
+ Seq(ArrayType, elementType)
+ }
+
+ lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
+
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(right.dataType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ super.checkInputDataTypes() match {
+ case f: TypeCheckResult.TypeCheckFailure => f
+ case TypeCheckResult.TypeCheckSuccess =>
+ TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
+ }
+ }
+
+ override def nullSafeEval(arr: Any, value: Any): Any = {
+ val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements())
+ var pos = 0
+ arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
+ if (v == null || !ordering.equiv(v, value)) {
+ newArray(pos) = v
+ pos += 1
+ }
+ )
+ new GenericArrayData(newArray.slice(0, pos))
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (arr, value) => {
+ val numsToRemove = ctx.freshName("numsToRemove")
+ val newArraySize = ctx.freshName("newArraySize")
+ val i = ctx.freshName("i")
+ val getValue = CodeGenerator.getValue(arr, elementType, i)
+ val isEqual = ctx.genEqual(elementType, value, getValue)
+ s"""
+ |int $numsToRemove = 0;
+ |for (int $i = 0; $i < $arr.numElements(); $i ++) {
+ | if (!$arr.isNullAt($i) && $isEqual) {
+ | $numsToRemove = $numsToRemove + 1;
+ | }
+ |}
+ |int $newArraySize = $arr.numElements() - $numsToRemove;
+ |${genCodeForResult(ctx, ev, arr, value, newArraySize)}
+ """.stripMargin
+ })
+ }
+
+ def genCodeForResult(
+ ctx: CodegenContext,
+ ev: ExprCode,
+ inputArray: String,
+ value: String,
+ newArraySize: String): String = {
+ val values = ctx.freshName("values")
+ val i = ctx.freshName("i")
+ val pos = ctx.freshName("pos")
+ val getValue = CodeGenerator.getValue(inputArray, elementType, i)
+ val isEqual = ctx.genEqual(elementType, value, getValue)
+ if (!CodeGenerator.isPrimitiveType(elementType)) {
+ val arrayClass = classOf[GenericArrayData].getName
+ s"""
+ |int $pos = 0;
+ |Object[] $values = new Object[$newArraySize];
+ |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
+ | if ($inputArray.isNullAt($i)) {
+ | $values[$pos] = null;
+ | $pos = $pos + 1;
+ | }
+ | else {
+ | if (!($isEqual)) {
+ | $values[$pos] = $getValue;
+ | $pos = $pos + 1;
+ | }
+ | }
+ |}
+ |${ev.value} = new $arrayClass($values);
+ """.stripMargin
+ } else {
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+ s"""
+ |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")}
+ |int $pos = 0;
+ |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
+ | if ($inputArray.isNullAt($i)) {
+ | $values.setNullAt($pos);
+ | $pos = $pos + 1;
+ | }
+ | else {
+ | if (!($isEqual)) {
+ | $values.set$primitiveValueTypeName($pos, $getValue);
+ | $pos = $pos + 1;
+ | }
+ | }
+ |}
+ |${ev.value} = $values;
+ """.stripMargin
+ }
+ }
+
+ override def prettyName: String = "array_remove"
+}
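// Illustrative sketch, not part of the patch: array_remove keeps nulls and drops only the entries
// that compare equal to the value being removed, as nullSafeEval above does (a null array or a
// null value would make the whole result null; equality here uses Scala == instead of the
// interpreted ordering).
def arrayRemoveSketch(arr: Seq[Any], value: Any): Seq[Any] =
  arr.filter(v => v == null || v != value)
// arrayRemoveSketch(Seq(1, 2, 3, null, 3), 3)  => List(1, 2, null)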
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 047b80ac5289c..0a5f8a907b50a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -21,7 +21,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -63,9 +64,9 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
val (preprocess, assigns, postprocess, arrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
ev.copy(
- code = preprocess + assigns + postprocess,
- value = arrayData,
- isNull = "false")
+ code = code"${preprocess}${assigns}${postprocess}",
+ value = JavaCode.variable(arrayData, dataType),
+ isNull = FalseLiteral)
}
override def prettyName: String = "array"
@@ -90,7 +91,7 @@ private [sql] object GenArrayData {
val arrayDataName = ctx.freshName("arrayData")
val numElements = elementsCode.length
- if (!ctx.isPrimitiveType(elementType)) {
+ if (!CodeGenerator.isPrimitiveType(elementType)) {
val arrayName = ctx.freshName("arrayObject")
val genericArrayClass = classOf[GenericArrayData].getName
@@ -124,7 +125,7 @@ private [sql] object GenArrayData {
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
val baseOffset = Platform.BYTE_ARRAY_OFFSET
- val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
val isNullAssignment = if (!isMapKey) {
s"$arrayDataName.setNullAt($i);"
@@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false)
val code =
- s"""
+ code"""
final boolean ${ev.isNull} = false;
$preprocessKeyData
$assignKeys
@@ -235,6 +236,76 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
override def prettyName: String = "map"
}
+/**
+ * Returns a catalyst Map built from the two child array expressions, using the first array
+ * as the keys and the second as the values.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. All elements
+      in keys should not be null.""",
+ examples = """
+ Examples:
+      > SELECT _FUNC_(array(1.0, 3.0), array('2', '4'));
+ {1.0:"2",3.0:"4"}
+ """, since = "2.4.0")
+case class MapFromArrays(left: Expression, right: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)
+
+ override def dataType: DataType = {
+ MapType(
+ keyType = left.dataType.asInstanceOf[ArrayType].elementType,
+ valueType = right.dataType.asInstanceOf[ArrayType].elementType,
+ valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull)
+ }
+
+ override def nullSafeEval(keyArray: Any, valueArray: Any): Any = {
+ val keyArrayData = keyArray.asInstanceOf[ArrayData]
+ val valueArrayData = valueArray.asInstanceOf[ArrayData]
+ if (keyArrayData.numElements != valueArrayData.numElements) {
+ throw new RuntimeException("The given two arrays should have the same length")
+ }
+ val leftArrayType = left.dataType.asInstanceOf[ArrayType]
+ if (leftArrayType.containsNull) {
+ var i = 0
+ while (i < keyArrayData.numElements) {
+ if (keyArrayData.isNullAt(i)) {
+ throw new RuntimeException("Cannot use null as map key!")
+ }
+ i += 1
+ }
+ }
+ new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy())
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => {
+ val arrayBasedMapData = classOf[ArrayBasedMapData].getName
+ val leftArrayType = left.dataType.asInstanceOf[ArrayType]
+ val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else {
+ val i = ctx.freshName("i")
+ s"""
+ |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) {
+ | if ($keyArrayData.isNullAt($i)) {
+ | throw new RuntimeException("Cannot use null as map key!");
+ | }
+ |}
+ """.stripMargin
+ }
+ s"""
+ |if ($keyArrayData.numElements() != $valueArrayData.numElements()) {
+ | throw new RuntimeException("The given two arrays should have the same length");
+ |}
+ |$keyArrayElemNullCheck
+ |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy());
+ """.stripMargin
+ })
+ }
+
+ override def prettyName: String = "map_from_arrays"
+}
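// Illustrative sketch, not part of the patch: map_from_arrays pairs the two input arrays, failing
// on mismatched lengths or null keys, as nullSafeEval above does. A Scala Map is used only for
// illustration; the real expression builds an ArrayBasedMapData that keeps the original order.
def mapFromArraysSketch(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
  require(keys.length == values.length, "The given two arrays should have the same length")
  require(!keys.contains(null), "Cannot use null as map key!")
  keys.zip(values).toMap
}
// mapFromArraysSketch(Seq(1.0, 3.0), Seq("2", "4"))  => Map(1.0 -> "2", 3.0 -> "4")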
+
/**
* An expression representing a not yet available attribute name. This expression is unevaluable
* and as its name suggests it is a temporary place holder until we're able to determine the
@@ -373,12 +444,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
extraArguments = "Object[]" -> values :: Nil)
ev.copy(code =
- s"""
+ code"""
|Object[] $values = new Object[${valExprs.size}];
|$valuesCode
|final InternalRow ${ev.value} = new $rowClass($values);
|$values = null;
- """.stripMargin, isNull = "false")
+ """.stripMargin, isNull = FalseLiteral)
}
override def prettyName: String = "named_struct"
@@ -394,7 +465,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
- ExprCode(code = eval.code, isNull = "false", value = eval.value)
+ ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value)
}
override def prettyName: String = "named_struct_unsafe"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 7e53ca3908905..99671d5b863c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -129,12 +129,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
if ($eval.isNullAt($ordinal)) {
${ev.isNull} = true;
} else {
- ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
+ ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)};
}
"""
} else {
s"""
- ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
+ ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)};
"""
}
})
@@ -205,7 +205,7 @@ case class GetArrayStructFields(
} else {
final InternalRow $row = $eval.getStruct($j, $numFields);
$nullSafeEval {
- $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
+ $values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)};
}
}
}
@@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
${ev.isNull} = true;
} else {
- ${ev.value} = ${ctx.getValue(eval1, dataType, index)};
+ ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
}
"""
})
@@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
}
/**
- * Returns the value of key `key` in Map `child`.
- *
- * We need to do type checking here as `key` expression maybe unresolved.
+ * Common base class for [[GetMapValue]] and [[ElementAt]].
*/
-case class GetMapValue(child: Expression, key: Expression)
- extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant {
-
- private def keyType = child.dataType.asInstanceOf[MapType].keyType
-
- // We have done type checking for child in `ExtractValue`, so only need to check the `key`.
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
-
- override def toString: String = s"$child[$key]"
- override def sql: String = s"${child.sql}[${key.sql}]"
-
- override def left: Expression = child
- override def right: Expression = key
-
- /** `Null` is returned for invalid ordinals. */
- override def nullable: Boolean = true
-
- override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
+abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
// todo: current search is O(n), improve it.
- protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
+ def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
val map = value.asInstanceOf[MapData]
val length = map.numElements()
val keys = map.keyArray()
@@ -301,7 +282,7 @@ case class GetMapValue(child: Expression, key: Expression)
var i = 0
var found = false
while (i < length && !found) {
- if (keys.get(i, keyType) == ordinal) {
+ if (ordering.equiv(keys.get(i, keyType), ordinal)) {
found = true
} else {
i += 1
@@ -315,18 +296,20 @@ case class GetMapValue(child: Expression, key: Expression)
}
}
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = {
val index = ctx.freshName("index")
val length = ctx.freshName("length")
val keys = ctx.freshName("keys")
val found = ctx.freshName("found")
val key = ctx.freshName("key")
val values = ctx.freshName("values")
- val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) {
+ val keyType = mapType.keyType
+ val nullCheck = if (mapType.valueContainsNull) {
s" || $values.isNullAt($index)"
} else {
""
}
+ val keyJavaType = CodeGenerator.javaType(keyType)
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
final int $length = $eval1.numElements();
@@ -336,7 +319,7 @@ case class GetMapValue(child: Expression, key: Expression)
int $index = 0;
boolean $found = false;
while ($index < $length && !$found) {
- final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)};
+ final $keyJavaType $key = ${CodeGenerator.getValue(keys, keyType, index)};
if (${ctx.genEqual(keyType, key, eval2)}) {
$found = true;
} else {
@@ -347,9 +330,54 @@ case class GetMapValue(child: Expression, key: Expression)
if (!$found$nullCheck) {
${ev.isNull} = true;
} else {
- ${ev.value} = ${ctx.getValue(values, dataType, index)};
+ ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)};
}
"""
})
}
}
+
+/**
+ * Returns the value of key `key` in Map `child`.
+ *
+ * We need to do type checking here as `key` expression maybe unresolved.
+ */
+case class GetMapValue(child: Expression, key: Expression)
+ extends GetMapValueUtil with ExtractValue with NullIntolerant {
+
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(keyType)
+
+ private def keyType = child.dataType.asInstanceOf[MapType].keyType
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ super.checkInputDataTypes() match {
+ case f: TypeCheckResult.TypeCheckFailure => f
+ case TypeCheckResult.TypeCheckSuccess =>
+ TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName")
+ }
+ }
+
+ // We have done type checking for child in `ExtractValue`, so only need to check the `key`.
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
+
+ override def toString: String = s"$child[$key]"
+ override def sql: String = s"${child.sql}[${key.sql}]"
+
+ override def left: Expression = child
+ override def right: Expression = key
+
+ /** `Null` is returned for invalid ordinals. */
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
+
+ // todo: current search is O(n), improve it.
+ override def nullSafeEval(value: Any, ordinal: Any): Any = {
+ getValueEval(value, ordinal, keyType, ordering)
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType])
+ }
+}
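// Illustrative sketch, not part of the patch: the O(n) key lookup that GetMapValueUtil.getValueEval
// above shares between GetMapValue and ElementAt, with the interpreted ordering abstracted as an
// equivalence function.
def mapLookupSketch[K, V](keys: IndexedSeq[K], values: IndexedSeq[V], key: K)
                         (equiv: (K, K) => Boolean): Option[V] = {
  var i = 0
  while (i < keys.length) {
    if (equiv(keys(i), key)) return Some(values(i))
    i += 1
  }
  None // SQL NULL for a missing key
}
// mapLookupSketch(IndexedSeq(1, 2, 3), IndexedSeq("a", "b", "c"), 2)(_ == _)  => Some("b")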
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index b444c3a7be92a..77ac6c088022e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
// scalastyle:off line.size.limit
@@ -66,10 +67,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
val falseEval = falseValue.genCode(ctx)
val code =
- s"""
+ code"""
|${condEval.code}
|boolean ${ev.isNull} = false;
- |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${condEval.isNull} && ${condEval.value}) {
| ${trueEval.code}
| ${ev.isNull} = ${trueEval.isNull};
@@ -191,7 +192,9 @@ case class CaseWhen(
// It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
// We won't go on anymore on the computation.
val resultState = ctx.freshName("caseWhenResultState")
- ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value)
+ ev.value = JavaCode.global(
+ ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value),
+ dataType)
// these blocks are meant to be inside a
// do {
@@ -244,10 +247,10 @@ case class CaseWhen(
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = allConditions,
funcName = "caseWhen",
- returnType = ctx.JAVA_BYTE,
+ returnType = CodeGenerator.JAVA_BYTE,
makeSplitFunction = func =>
s"""
- |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
+ |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $func
|} while (false);
@@ -263,8 +266,8 @@ case class CaseWhen(
}.mkString)
ev.copy(code =
- s"""
- |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
+ code"""
+ |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $codes
|} while (false);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 424871f2047e9..08838d2b2c612 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -26,7 +26,8 @@ import scala.util.control.NonFatal
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -426,36 +427,71 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa
""",
since = "2.3.0")
// scalastyle:on line.size.limit
-case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class DayOfWeek(child: Expression) extends DayWeek {
- override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
-
- override def dataType: DataType = IntegerType
+ override protected def nullSafeEval(date: Any): Any = {
+ cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L)
+ cal.get(Calendar.DAY_OF_WEEK)
+ }
- @transient private lazy val c = {
- Calendar.getInstance(DateTimeUtils.getTimeZone("UTC"))
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, time => {
+ val cal = classOf[Calendar].getName
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ val c = "calDayOfWeek"
+ ctx.addImmutableStateIfNotExists(cal, c,
+ v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""")
+ s"""
+ $c.setTimeInMillis($time * 1000L * 3600L * 24L);
+ ${ev.value} = $c.get($cal.DAY_OF_WEEK);
+ """
+ })
}
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(date) - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday).",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('2009-07-30');
+ 3
+ """,
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class WeekDay(child: Expression) extends DayWeek {
override protected def nullSafeEval(date: Any): Any = {
- c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L)
- c.get(Calendar.DAY_OF_WEEK)
+ cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L)
+    (cal.get(Calendar.DAY_OF_WEEK) + 5) % 7
}
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, time => {
val cal = classOf[Calendar].getName
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- val c = "calDayOfWeek"
+ val c = "calWeekDay"
ctx.addImmutableStateIfNotExists(cal, c,
v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""")
s"""
$c.setTimeInMillis($time * 1000L * 3600L * 24L);
- ${ev.value} = $c.get($cal.DAY_OF_WEEK);
+ ${ev.value} = ($c.get($cal.DAY_OF_WEEK) + 5) % 7;
"""
})
}
}
+abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
+
+ override def dataType: DataType = IntegerType
+
+ @transient protected lazy val cal: Calendar = {
+ Calendar.getInstance(DateTimeUtils.getTimeZone("UTC"))
+ }
+}
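// Illustrative note, not part of the patch: how the two numberings relate. java.util.Calendar uses
// SUNDAY = 1 .. SATURDAY = 7; dayofweek returns that value directly, while weekday remaps it to
// Monday = 0 .. Sunday = 6 via (dayOfWeek + 5) % 7:
//   Calendar.MONDAY (2)  => (2 + 5) % 7 = 0
//   Calendar.SUNDAY (1)  => (1 + 5) % 7 = 6
def weekDayFromCalendarSketch(calendarDayOfWeek: Int): Int = (calendarDayOfWeek + 5) % 7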
+
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week 1 is the first week with >3 days.",
@@ -673,18 +709,19 @@ abstract class UnixTime
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val javaType = CodeGenerator.javaType(dataType)
left.dataType match {
case StringType if right.foldable =>
val df = classOf[DateFormat].getName
if (formatter == null) {
- ExprCode("", "true", ctx.defaultValue(dataType))
+ ExprCode.forNullValue(dataType)
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val eval1 = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
try {
${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L;
@@ -710,10 +747,10 @@ abstract class UnixTime
})
case TimestampType =>
val eval1 = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = ${eval1.value} / 1000000L;
}""")
@@ -721,10 +758,10 @@ abstract class UnixTime
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val eval1 = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L;
}""")
@@ -812,14 +849,14 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
val df = classOf[DateFormat].getName
if (format.foldable) {
if (formatter == null) {
- ExprCode("", "true", "(UTF8String) null")
+ ExprCode.forNullValue(StringType)
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val t = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
try {
${ev.value} = UTF8String.fromString($formatterName.format(
@@ -980,6 +1017,48 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
}
}
+/**
+ * A special expression used to convert the string input of `to/from_utc_timestamp` to timestamp,
+ * which requires the timestamp string to not have timezone information, otherwise null is returned.
+ */
+case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Option[String] = None)
+ extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes {
+
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+ override def dataType: DataType = TimestampType
+ override def nullable: Boolean = true
+ override def toString: String = child.toString
+ override def sql: String = child.sql
+
+ override def nullSafeEval(input: Any): Any = {
+ DateTimeUtils.stringToTimestamp(
+ input.asInstanceOf[UTF8String], timeZone, rejectTzInString = true).orNull
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ val tz = ctx.addReferenceObj("timeZone", timeZone)
+ val longOpt = ctx.freshName("longOpt")
+ val eval = child.genCode(ctx)
+ val code = code"""
+ |${eval.code}
+ |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true;
+ |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)};
+ |if (!${eval.isNull}) {
+ | scala.Option $longOpt = $dtu.stringToTimestamp(${eval.value}, $tz, true);
+ | if ($longOpt.isDefined()) {
+ | ${ev.value} = ((Long) $longOpt.get()).longValue();
+ | ${ev.isNull} = false;
+ | }
+ |}
+ """.stripMargin
+ ev.copy(code = code)
+ }
+}
+
/**
* Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders
* that time as a timestamp in the given time zone. For example, 'GMT+1' would yield
@@ -1012,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
if (right.foldable) {
val tz = right.eval().asInstanceOf[UTF8String]
if (tz == null) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
@@ -1026,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
|${eval.code}
|boolean ${ev.isNull} = ${eval.isNull};
|long ${ev.value} = 0;
@@ -1116,42 +1195,61 @@ case class AddMonths(startDate: Expression, numMonths: Expression)
}
/**
- * Returns number of months between dates date1 and date2.
+ * Returns number of months between times `timestamp1` and `timestamp2`.
+ * If `timestamp1` is later than `timestamp2`, then the result is positive.
+ * If `timestamp1` and `timestamp2` are on the same day of month, or both
+ * are the last day of month, time of day will be ignored. Otherwise, the
+ * difference is calculated based on 31 days per month, and rounded to
+ * 8 digits unless roundOff=false.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(timestamp1, timestamp2) - Returns number of months between `timestamp1` and `timestamp2`.",
+ usage = """
+ _FUNC_(timestamp1, timestamp2[, roundOff]) - If `timestamp1` is later than `timestamp2`, then the result
+ is positive. If `timestamp1` and `timestamp2` are on the same day of month, or both
+ are the last day of month, time of day will be ignored. Otherwise, the difference is
+ calculated based on 31 days per month, and rounded to 8 digits unless roundOff=false.
+ """,
examples = """
Examples:
> SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30');
3.94959677
+ > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30', false);
+ 3.9495967741935485
""",
since = "1.5.0")
// scalastyle:on line.size.limit
-case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Option[String] = None)
- extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+case class MonthsBetween(
+ date1: Expression,
+ date2: Expression,
+ roundOff: Expression,
+ timeZoneId: Option[String] = None)
+ extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+
+ def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None)
- def this(date1: Expression, date2: Expression) = this(date1, date2, None)
+ def this(date1: Expression, date2: Expression, roundOff: Expression) =
+ this(date1, date2, roundOff, None)
- override def left: Expression = date1
- override def right: Expression = date2
+ override def children: Seq[Expression] = Seq(date1, date2, roundOff)
- override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType)
+ override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType, BooleanType)
override def dataType: DataType = DoubleType
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
- override def nullSafeEval(t1: Any, t2: Any): Any = {
- DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long], timeZone)
+ override def nullSafeEval(t1: Any, t2: Any, roundOff: Any): Any = {
+ DateTimeUtils.monthsBetween(
+ t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], timeZone)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, (l, r) => {
- s"""$dtu.monthsBetween($l, $r, $tz)"""
+ defineCodeGen(ctx, ev, (d1, d2, roundOff) => {
+ s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)"""
})
}
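// Illustrative arithmetic, not part of the patch: the first example above, computed the way the
// documentation describes (31 days per month, rounded to 8 digits unless roundOff=false).
// '1997-02-28 10:30:00' vs '1996-10-30 00:00:00':
//   whole months:  (1997 * 12 + 2) - (1996 * 12 + 10)        =  4
//   leftover secs: (28 - 30) * 86400 + (10.5 * 3600 - 0)     = -135000
//   fraction:      -135000 / (31 * 86400)                    = -0.0504032258...
//   result:        4 - 0.0504032258... = 3.9495967741935485, i.e. 3.94959677 after rounding
val monthsBetweenExample: Double = 4.0 + ((28 - 30) * 86400 + 10.5 * 3600) / (31.0 * 86400)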
@@ -1190,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
if (right.foldable) {
val tz = right.eval().asInstanceOf[UTF8String]
if (tz == null) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
@@ -1204,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
|${eval.code}
|boolean ${ev.isNull} = ${eval.isNull};
|long ${ev.value} = 0;
@@ -1344,18 +1442,19 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
: ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ val javaType = CodeGenerator.javaType(dataType)
if (format.foldable) {
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
val t = instant.genCode(ctx)
val truncFuncStr = truncFunc(t.value, truncLevel.toString)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.$truncFuncStr;
}""")
@@ -1434,14 +1533,14 @@ case class TruncDate(date: Expression, format: Expression)
""",
examples = """
Examples:
- > SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR');
- 2015-01-01T00:00:00
- > SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM');
- 2015-03-01T00:00:00
- > SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD');
- 2015-03-05T00:00:00
- > SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR');
- 2015-03-05T09:00:00
+ > SELECT _FUNC_('YEAR', '2015-03-05T09:32:05.359');
+ 2015-01-01 00:00:00
+ > SELECT _FUNC_('MM', '2015-03-05T09:32:05.359');
+ 2015-03-01 00:00:00
+ > SELECT _FUNC_('DD', '2015-03-05T09:32:05.359');
+ 2015-03-05 00:00:00
+ > SELECT _FUNC_('HOUR', '2015-03-05T09:32:05.359');
+ 2015-03-05 09:00:00
""",
since = "2.3.0")
// scalastyle:on line.size.limit
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 312e04cd297b7..ad7f7dd9434a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
import org.apache.spark.sql.types._
/**
@@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
override def eval(input: InternalRow): Any = child.eval(input)
/** Just a simple pass-through for code generation. */
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("")
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ ev.copy(EmptyBlock)
override def prettyName: String = "promote_precision"
override def sql: String = child.sql
override lazy val canonicalized: Expression = child.canonicalized
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 4f4d49166e88c..b7c52f1d7b40a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -22,7 +22,8 @@ import scala.collection.mutable
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
@@ -215,10 +216,10 @@ case class Stack(children: Seq[Expression]) extends Generator {
// Create the collection.
val wrapperClass = classOf[mutable.WrappedArray[_]].getName
ev.copy(code =
- s"""
+ code"""
|$code
|$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);
- """.stripMargin, isNull = "false")
+ """.stripMargin, isNull = FalseLiteral)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 055ebf6c0da54..cec00b66f873c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -28,10 +28,12 @@ import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32
+import org.apache.spark.unsafe.memory.MemoryBlock
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -269,7 +271,7 @@ abstract class HashExpression[E] extends Expression {
protected def computeHash(value: Any, dataType: DataType, seed: E): E
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- ev.isNull = "false"
+ ev.isNull = FalseLiteral
val childrenHash = children.map { child =>
val childGen = child.genCode(ctx)
@@ -278,7 +280,7 @@ abstract class HashExpression[E] extends Expression {
}
}
- val hashResultType = ctx.javaType(dataType)
+ val hashResultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = childrenHash,
funcName = "computeHash",
@@ -292,7 +294,7 @@ abstract class HashExpression[E] extends Expression {
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|$hashResultType ${ev.value} = $seed;
|$codes
""".stripMargin)
@@ -307,9 +309,10 @@ abstract class HashExpression[E] extends Expression {
ctx: CodegenContext): String = {
val element = ctx.freshName("element")
+ val jt = CodeGenerator.javaType(elementType)
ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
s"""
- final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
+ final $jt $element = ${CodeGenerator.getValue(input, elementType, index)};
${computeHash(element, elementType, result, ctx)}
"""
}
@@ -359,10 +362,7 @@ abstract class HashExpression[E] extends Expression {
}
protected def genHashString(input: String, result: String): String = {
- val baseObject = s"$input.getBaseObject()"
- val baseOffset = s"$input.getBaseOffset()"
- val numBytes = s"$input.numBytes()"
- s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
+ s"$result = $hasherClassName.hashUTF8String($input, $result);"
}
protected def genHashForMap(
@@ -407,7 +407,7 @@ abstract class HashExpression[E] extends Expression {
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}
- val hashResultType = ctx.javaType(dataType)
+ val hashResultType = CodeGenerator.javaType(dataType)
ctx.splitExpressions(
expressions = fieldsHash,
funcName = "computeHashForStruct",
@@ -464,6 +464,8 @@ abstract class InterpretedHashFunction {
protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long
+ protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long
+
/**
* Computes hash of a given `value` of type `dataType`. The caller needs to check the validity
* of input `value`.
@@ -489,8 +491,7 @@ abstract class InterpretedHashFunction {
case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed))
case a: Array[Byte] =>
hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
- case s: UTF8String =>
- hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)
+ case s: UTF8String => hashUnsafeBytesBlock(s.getMemoryBlock(), seed)
case array: ArrayData =>
val elementType = dataType match {
@@ -577,9 +578,15 @@ object Murmur3HashFunction extends InterpretedHashFunction {
Murmur3_x86_32.hashLong(l, seed.toInt)
}
- override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
+ override protected def hashUnsafeBytes(
+ base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt)
}
+
+ override protected def hashUnsafeBytesBlock(
+ base: MemoryBlock, seed: Long): Long = {
+ Murmur3_x86_32.hashUnsafeBytesBlock(base, seed.toInt)
+ }
}
/**
@@ -604,9 +611,14 @@ object XxHash64Function extends InterpretedHashFunction {
override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed)
- override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
+ override protected def hashUnsafeBytes(
+ base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
XXH64.hashUnsafeBytes(base, offset, len, seed)
}
+
+ override protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long = {
+ XXH64.hashUnsafeBytesBlock(base, seed)
+ }
}
/**
@@ -632,7 +644,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- ev.isNull = "false"
+ ev.isNull = FalseLiteral
val childHash = ctx.freshName("childHash")
val childrenHash = children.map { child =>
@@ -651,11 +663,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = childrenHash,
funcName = "computeHash",
- extraArguments = Seq(ctx.JAVA_INT -> ev.value),
- returnType = ctx.JAVA_INT,
+ extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value),
+ returnType = CodeGenerator.JAVA_INT,
makeSplitFunction = body =>
s"""
- |${ctx.JAVA_INT} $childHash = 0;
+ |${CodeGenerator.JAVA_INT} $childHash = 0;
|$body
|return ${ev.value};
""".stripMargin,
@@ -663,9 +675,9 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
ev.copy(code =
- s"""
- |${ctx.JAVA_INT} ${ev.value} = $seed;
- |${ctx.JAVA_INT} $childHash = 0;
+ code"""
+ |${CodeGenerator.JAVA_INT} ${ev.value} = $seed;
+ |${CodeGenerator.JAVA_INT} $childHash = 0;
|$codes
""".stripMargin)
}
@@ -713,10 +725,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
"""
override protected def genHashString(input: String, result: String): String = {
- val baseObject = s"$input.getBaseObject()"
- val baseOffset = s"$input.getBaseOffset()"
- val numBytes = s"$input.numBytes()"
- s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);"
+ s"$result = $hasherClassName.hashUTF8String($input);"
}
override protected def genHashForArray(
@@ -780,14 +789,14 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
""".stripMargin
}
- s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
+ s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
expressions = fieldsHash,
funcName = "computeHashForStruct",
- arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result),
- returnType = ctx.JAVA_INT,
+ arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result),
+ returnType = CodeGenerator.JAVA_INT,
makeSplitFunction = body =>
s"""
- |${ctx.JAVA_INT} $childResult = 0;
+ |${CodeGenerator.JAVA_INT} $childResult = 0;
|$body
|return $result;
""".stripMargin,
@@ -804,10 +813,14 @@ object HiveHashFunction extends InterpretedHashFunction {
HiveHasher.hashLong(l)
}
- override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
+ override protected def hashUnsafeBytes(
+ base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
HiveHasher.hashUnsafeBytes(base, offset, len)
}
+ override protected def hashUnsafeBytesBlock(
+ base: MemoryBlock, seed: Long): Long = HiveHasher.hashUnsafeBytesBlock(base)
+
private val HIVE_DECIMAL_MAX_PRECISION = 38
private val HIVE_DECIMAL_MAX_SCALE = 38
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
index 7a8edabed1757..3b0141ad52cc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
- s"$className.getInputFilePath();", isNull = "false")
+ val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
+ ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();",
+ isNull = FalseLiteral)
}
}
@@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
- s"$className.getStartOffset();", isNull = "false")
+ val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
+ ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral)
}
}
@@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
- s"$className.getLength();", isNull = "false")
+ val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
+ ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 18b4fed597447..f6d74f5b74c8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -513,12 +513,16 @@ case class JsonToStructs(
schema: DataType,
options: Map[String, String],
child: Expression,
- timeZoneId: Option[String] = None)
+ timeZoneId: Option[String],
+ forceNullableSchema: Boolean)
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
- override def nullable: Boolean = true
- def this(schema: DataType, options: Map[String, String], child: Expression) =
- this(schema, options, child, None)
+ // The JSON input data might be missing certain fields. We force the nullability
+  // of the user-provided schema to avoid data corruption. In particular, the parquet-mr encoder
+ // can generate incorrect files if values are missing in columns declared as non-nullable.
+ val nullableSchema = if (forceNullableSchema) schema.asNullable else schema
+
+ override def nullable: Boolean = true
// Used in `FunctionRegistry`
def this(child: Expression, schema: Expression) =
@@ -526,35 +530,45 @@ case class JsonToStructs(
schema = JsonExprUtils.validateSchemaLiteral(schema),
options = Map.empty[String, String],
child = child,
- timeZoneId = None)
+ timeZoneId = None,
+ forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
def this(child: Expression, schema: Expression, options: Expression) =
this(
schema = JsonExprUtils.validateSchemaLiteral(schema),
options = JsonExprUtils.convertToMapData(options),
child = child,
- timeZoneId = None)
+ timeZoneId = None,
+ forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
- override def checkInputDataTypes(): TypeCheckResult = schema match {
- case _: StructType | ArrayType(_: StructType, _) =>
+ // Used in `org.apache.spark.sql.functions`
+ def this(schema: DataType, options: Map[String, String], child: Expression) =
+ this(schema, options, child, timeZoneId = None,
+ forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
+
+ override def checkInputDataTypes(): TypeCheckResult = nullableSchema match {
+ case _: StructType | ArrayType(_: StructType, _) | _: MapType =>
super.checkInputDataTypes()
case _ => TypeCheckResult.TypeCheckFailure(
- s"Input schema ${schema.simpleString} must be a struct or an array of structs.")
+ s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.")
}
@transient
- lazy val rowSchema = schema match {
+ lazy val rowSchema = nullableSchema match {
case st: StructType => st
case ArrayType(st: StructType, _) => st
+ case mt: MapType => mt
}
// This converts parsed rows to the desired output by the given schema.
@transient
- lazy val converter = schema match {
+ lazy val converter = nullableSchema match {
case _: StructType =>
(rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
case ArrayType(_: StructType, _) =>
(rows: Seq[InternalRow]) => new GenericArrayData(rows)
+ case _: MapType =>
+ (rows: Seq[InternalRow]) => rows.head.getMap(0)
}
@transient
@@ -563,7 +577,7 @@ case class JsonToStructs(
rowSchema,
new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get))
- override def dataType: DataType = schema
+ override def dataType: DataType = nullableSchema
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
@@ -601,6 +615,11 @@ case class JsonToStructs(
}
override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
+
+ override def sql: String = schema match {
+ case _: MapType => "entries"
+ case _ => super.sql
+ }
}
/**
@@ -727,8 +746,8 @@ case class StructsToJson(
object JsonExprUtils {
- def validateSchemaLiteral(exp: Expression): StructType = exp match {
- case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString)
+ def validateSchemaLiteral(exp: Expression): DataType = exp match {
+ case Literal(s, StringType) => DataType.fromDDL(s.toString)
case e => throw new AnalysisException(s"Expected a string literal instead of $e")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index cd176d941819f..246025b82d59e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -277,41 +277,45 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
override def eval(input: InternalRow): Any = value
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val javaType = ctx.javaType(dataType)
- // change the isNull and primitive to consts, to inline them
+ val javaType = CodeGenerator.javaType(dataType)
if (value == null) {
- ev.isNull = "true"
- ev.copy(s"final $javaType ${ev.value} = ${ctx.defaultValue(dataType)};")
+ ExprCode.forNullValue(dataType)
} else {
- ev.isNull = "false"
+ def toExprCode(code: String): ExprCode = {
+ ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
+ }
dataType match {
case BooleanType | IntegerType | DateType =>
- ev.copy(code = "", value = value.toString)
+ toExprCode(value.toString)
case FloatType =>
- val v = value.asInstanceOf[Float]
- if (v.isNaN || v.isInfinite) {
- val boxedValue = ctx.addReferenceObj("boxedValue", v)
- val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;"
- ev.copy(code = code)
- } else {
- ev.copy(code = "", value = s"${value}f")
+ value.asInstanceOf[Float] match {
+ case v if v.isNaN =>
+ toExprCode("Float.NaN")
+ case Float.PositiveInfinity =>
+ toExprCode("Float.POSITIVE_INFINITY")
+ case Float.NegativeInfinity =>
+ toExprCode("Float.NEGATIVE_INFINITY")
+ case _ =>
+ toExprCode(s"${value}F")
}
case DoubleType =>
- val v = value.asInstanceOf[Double]
- if (v.isNaN || v.isInfinite) {
- val boxedValue = ctx.addReferenceObj("boxedValue", v)
- val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;"
- ev.copy(code = code)
- } else {
- ev.copy(code = "", value = s"${value}D")
+ value.asInstanceOf[Double] match {
+ case v if v.isNaN =>
+ toExprCode("Double.NaN")
+ case Double.PositiveInfinity =>
+ toExprCode("Double.POSITIVE_INFINITY")
+ case Double.NegativeInfinity =>
+ toExprCode("Double.NEGATIVE_INFINITY")
+ case _ =>
+ toExprCode(s"${value}D")
}
case ByteType | ShortType =>
- ev.copy(code = "", value = s"($javaType)$value")
+ ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
case TimestampType | LongType =>
- ev.copy(code = "", value = s"${value}L")
+ toExprCode(s"${value}L")
case _ =>
- ev.copy(code = "", value = ctx.addReferenceObj("literal", value,
- ctx.javaType(dataType)))
+ val constRef = ctx.addReferenceObj("literal", value, javaType)
+ ExprCode.forNonNullValue(JavaCode.global(constRef, dataType))
}
}
}
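
The rewritten Literal codegen above emits NaN and infinities as their named Java constants instead of adding boxed references to the generated class. A standalone sketch of that mapping (the helper names are illustrative, not Spark APIs):

    // Float/double special values have stable Java constant names, so they can be inlined
    // into generated code; only the generic fallback still needs a referenced object.
    def floatLiteralJava(v: Float): String = v match {
      case f if f.isNaN           => "Float.NaN"
      case Float.PositiveInfinity => "Float.POSITIVE_INFINITY"
      case Float.NegativeInfinity => "Float.NEGATIVE_INFINITY"
      case f                      => s"${f}F"
    }

    def doubleLiteralJava(v: Double): String = v match {
      case d if d.isNaN            => "Double.NaN"
      case Double.PositiveInfinity => "Double.POSITIVE_INFINITY"
      case Double.NegativeInfinity => "Double.NEGATIVE_INFINITY"
      case d                       => s"${d}D"
    }
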
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
new file mode 100644
index 0000000000000..276a57266a6e0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala
@@ -0,0 +1,569 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.commons.codec.digest.DigestUtils
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._
+import org.apache.spark.sql.catalyst.expressions.MaskLike._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+
+trait MaskLike {
+ def upper: String
+ def lower: String
+ def digit: String
+
+ protected lazy val upperReplacement: Int = getReplacementChar(upper, defaultMaskedUppercase)
+ protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase)
+ protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit)
+
+ protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName
+
+ def inputStringLengthCode(inputString: String, length: String): String = {
+ s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());"
+ }
+
+ def appendMaskedToStringBuilderCode(
+ ctx: CodegenContext,
+ sb: String,
+ inputString: String,
+ offset: String,
+ numChars: String): String = {
+ val i = ctx.freshName("i")
+ val codePoint = ctx.freshName("codePoint")
+ s"""
+ |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) {
+ | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset);
+ | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint,
+ | $upperReplacement, $lowerReplacement,
+ | $digitReplacement, $defaultMaskedOther));
+ | $offset += Character.charCount($codePoint);
+ |}
+ """.stripMargin
+ }
+
+ def appendUnchangedToStringBuilderCode(
+ ctx: CodegenContext,
+ sb: String,
+ inputString: String,
+ offset: String,
+ numChars: String): String = {
+ val i = ctx.freshName("i")
+ val codePoint = ctx.freshName("codePoint")
+ s"""
+ |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) {
+ | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset);
+ | $sb.appendCodePoint($codePoint);
+ | $offset += Character.charCount($codePoint);
+ |}
+ """.stripMargin
+ }
+
+ def appendMaskedToStringBuilder(
+ sb: java.lang.StringBuilder,
+ inputString: String,
+ startOffset: Int,
+ numChars: Int): Int = {
+ var offset = startOffset
+ (1 to numChars) foreach { _ =>
+ val codePoint = inputString.codePointAt(offset)
+ sb.appendCodePoint(transformChar(
+ codePoint,
+ upperReplacement,
+ lowerReplacement,
+ digitReplacement,
+ defaultMaskedOther))
+ offset += Character.charCount(codePoint)
+ }
+ offset
+ }
+
+ def appendUnchangedToStringBuilder(
+ sb: java.lang.StringBuilder,
+ inputString: String,
+ startOffset: Int,
+ numChars: Int): Int = {
+ var offset = startOffset
+ (1 to numChars) foreach { _ =>
+ val codePoint = inputString.codePointAt(offset)
+ sb.appendCodePoint(codePoint)
+ offset += Character.charCount(codePoint)
+ }
+ offset
+ }
+}
+
+trait MaskLikeWithN extends MaskLike {
+ def n: Int
+ protected lazy val charCount: Int = if (n < 0) 0 else n
+}
+
+/**
+ * Utils for mask operations.
+ */
+object MaskLike {
+ val defaultCharCount = 4
+ val defaultMaskedUppercase: Int = 'X'
+ val defaultMaskedLowercase: Int = 'x'
+ val defaultMaskedDigit: Int = 'n'
+ val defaultMaskedOther: Int = MaskExpressionsUtils.UNMASKED_VAL
+
+ def extractCharCount(e: Expression): Int = e match {
+ case Literal(i, IntegerType | NullType) =>
+ if (i == null) defaultCharCount else i.asInstanceOf[Int]
+ case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " +
+ s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}")
+ case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}")
+ }
+
+ def extractReplacement(e: Expression): String = e match {
+ case Literal(s, StringType | NullType) => if (s == null) null else s.toString
+ case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " +
+ s"${StringType.simpleString}, but got literal of ${dt.simpleString}")
+ case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}")
+ }
+}
+
+/**
+ * Masks the input string. Additional parameters can be set to change the masking chars for
+ * uppercase letters, lowercase letters and digits.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(str[, upper[, lower[, digit]]]) - Masks str. By default, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_("abcd-EFGH-8765-4321", "U", "l", "#");
+ llll-UUUU-####-####
+ """)
+// scalastyle:on line.size.limit
+case class Mask(child: Expression, upper: String, lower: String, digit: String)
+ extends UnaryExpression with ExpectsInputTypes with MaskLike {
+
+ def this(child: Expression) = this(child, null.asInstanceOf[String], null, null)
+
+ def this(child: Expression, upper: Expression) =
+ this(child, extractReplacement(upper), null, null)
+
+ def this(child: Expression, upper: Expression, lower: Expression) =
+ this(child, extractReplacement(upper), extractReplacement(lower), null)
+
+ def this(child: Expression, upper: Expression, lower: Expression, digit: Expression) =
+ this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit))
+
+ override def nullSafeEval(input: Any): Any = {
+ val str = input.asInstanceOf[UTF8String].toString
+ val length = str.codePointCount(0, str.length())
+ val sb = new java.lang.StringBuilder(length)
+ appendMaskedToStringBuilder(sb, str, 0, length)
+ UTF8String.fromString(sb.toString)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (input: String) => {
+ val sb = ctx.freshName("sb")
+ val length = ctx.freshName("length")
+ val offset = ctx.freshName("offset")
+ val inputString = ctx.freshName("inputString")
+ s"""
+ |String $inputString = $input.toString();
+ |${inputStringLengthCode(inputString, length)}
+ |StringBuilder $sb = new StringBuilder($length);
+ |${CodeGenerator.JAVA_INT} $offset = 0;
+ |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)}
+ |${ev.value} = UTF8String.fromString($sb.toString());
+ """.stripMargin
+ })
+ }
+
+ override def dataType: DataType = StringType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+}
+
+/**
+ * Masks the first N chars of the input string. N defaults to 4. Additional parameters can be set
+ * to change the masking chars for uppercase letters, lowercase letters and digits.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_("1234-5678-8765-4321", 4);
+ nnnn-5678-8765-4321
+ """)
+// scalastyle:on line.size.limit
+case class MaskFirstN(
+ child: Expression,
+ n: Int,
+ upper: String,
+ lower: String,
+ digit: String)
+ extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN {
+
+ def this(child: Expression) =
+ this(child, defaultCharCount, null, null, null)
+
+ def this(child: Expression, n: Expression) =
+ this(child, extractCharCount(n), null, null, null)
+
+ def this(child: Expression, n: Expression, upper: Expression) =
+ this(child, extractCharCount(n), extractReplacement(upper), null, null)
+
+ def this(child: Expression, n: Expression, upper: Expression, lower: Expression) =
+ this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null)
+
+ def this(
+ child: Expression,
+ n: Expression,
+ upper: Expression,
+ lower: Expression,
+ digit: Expression) =
+ this(child,
+ extractCharCount(n),
+ extractReplacement(upper),
+ extractReplacement(lower),
+ extractReplacement(digit))
+
+ override def nullSafeEval(input: Any): Any = {
+ val str = input.asInstanceOf[UTF8String].toString
+ val length = str.codePointCount(0, str.length())
+ val endOfMask = if (charCount > length) length else charCount
+ val sb = new java.lang.StringBuilder(length)
+ val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask)
+ appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask)
+ UTF8String.fromString(sb.toString)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (input: String) => {
+ val sb = ctx.freshName("sb")
+ val length = ctx.freshName("length")
+ val offset = ctx.freshName("offset")
+ val inputString = ctx.freshName("inputString")
+ val endOfMask = ctx.freshName("endOfMask")
+ s"""
+ |String $inputString = $input.toString();
+ |${inputStringLengthCode(inputString, length)}
+ |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount;
+ |${CodeGenerator.JAVA_INT} $offset = 0;
+ |StringBuilder $sb = new StringBuilder($length);
+ |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)}
+ |${appendUnchangedToStringBuilderCode(
+ ctx, sb, inputString, offset, s"$length - $endOfMask")}
+ |${ev.value} = UTF8String.fromString($sb.toString());
+ |""".stripMargin
+ })
+ }
+
+ override def dataType: DataType = StringType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+
+ override def prettyName: String = "mask_first_n"
+}
+
+/**
+ * Masks the last N chars of the input string. N defaults to 4. Additional parameters can be set
+ * to change the masking chars for uppercase letters, lowercase letters and digits.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_("1234-5678-8765-4321", 4);
+ 1234-5678-8765-nnnn
+ """, since = "2.4.0")
+// scalastyle:on line.size.limit
+case class MaskLastN(
+ child: Expression,
+ n: Int,
+ upper: String,
+ lower: String,
+ digit: String)
+ extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN {
+
+ def this(child: Expression) =
+ this(child, defaultCharCount, null, null, null)
+
+ def this(child: Expression, n: Expression) =
+ this(child, extractCharCount(n), null, null, null)
+
+ def this(child: Expression, n: Expression, upper: Expression) =
+ this(child, extractCharCount(n), extractReplacement(upper), null, null)
+
+ def this(child: Expression, n: Expression, upper: Expression, lower: Expression) =
+ this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null)
+
+ def this(
+ child: Expression,
+ n: Expression,
+ upper: Expression,
+ lower: Expression,
+ digit: Expression) =
+ this(child,
+ extractCharCount(n),
+ extractReplacement(upper),
+ extractReplacement(lower),
+ extractReplacement(digit))
+
+ override def nullSafeEval(input: Any): Any = {
+ val str = input.asInstanceOf[UTF8String].toString
+ val length = str.codePointCount(0, str.length())
+ val startOfMask = if (charCount >= length) 0 else length - charCount
+ val sb = new java.lang.StringBuilder(length)
+ val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask)
+ appendMaskedToStringBuilder(sb, str, offset, length - startOfMask)
+ UTF8String.fromString(sb.toString)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (input: String) => {
+ val sb = ctx.freshName("sb")
+ val length = ctx.freshName("length")
+ val offset = ctx.freshName("offset")
+ val inputString = ctx.freshName("inputString")
+ val startOfMask = ctx.freshName("startOfMask")
+ s"""
+ |String $inputString = $input.toString();
+ |${inputStringLengthCode(inputString, length)}
+ |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ?
+ | 0 : $length - $charCount;
+ |${CodeGenerator.JAVA_INT} $offset = 0;
+ |StringBuilder $sb = new StringBuilder($length);
+ |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)}
+ |${appendMaskedToStringBuilderCode(
+ ctx, sb, inputString, offset, s"$length - $startOfMask")}
+ |${ev.value} = UTF8String.fromString($sb.toString());
+ |""".stripMargin
+ })
+ }
+
+ override def dataType: DataType = StringType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+
+ override def prettyName: String = "mask_last_n"
+}
+
+/**
+ * Masks all but the first N chars of the input string. N defaults to 4. Additional parameters can
+ * be set to change the masking chars for uppercase letters, lowercase letters and digits.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_("1234-5678-8765-4321", 4);
+ 1234-nnnn-nnnn-nnnn
+ """, since = "2.4.0")
+// scalastyle:on line.size.limit
+case class MaskShowFirstN(
+ child: Expression,
+ n: Int,
+ upper: String,
+ lower: String,
+ digit: String)
+ extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN {
+
+ def this(child: Expression) =
+ this(child, defaultCharCount, null, null, null)
+
+ def this(child: Expression, n: Expression) =
+ this(child, extractCharCount(n), null, null, null)
+
+ def this(child: Expression, n: Expression, upper: Expression) =
+ this(child, extractCharCount(n), extractReplacement(upper), null, null)
+
+ def this(child: Expression, n: Expression, upper: Expression, lower: Expression) =
+ this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null)
+
+ def this(
+ child: Expression,
+ n: Expression,
+ upper: Expression,
+ lower: Expression,
+ digit: Expression) =
+ this(child,
+ extractCharCount(n),
+ extractReplacement(upper),
+ extractReplacement(lower),
+ extractReplacement(digit))
+
+ override def nullSafeEval(input: Any): Any = {
+ val str = input.asInstanceOf[UTF8String].toString
+ val length = str.codePointCount(0, str.length())
+ val startOfMask = if (charCount > length) length else charCount
+ val sb = new java.lang.StringBuilder(length)
+ val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask)
+ appendMaskedToStringBuilder(sb, str, offset, length - startOfMask)
+ UTF8String.fromString(sb.toString)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (input: String) => {
+ val sb = ctx.freshName("sb")
+ val length = ctx.freshName("length")
+ val offset = ctx.freshName("offset")
+ val inputString = ctx.freshName("inputString")
+ val startOfMask = ctx.freshName("startOfMask")
+ s"""
+ |String $inputString = $input.toString();
+ |${inputStringLengthCode(inputString, length)}
+ |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount;
+ |${CodeGenerator.JAVA_INT} $offset = 0;
+ |StringBuilder $sb = new StringBuilder($length);
+ |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)}
+ |${appendMaskedToStringBuilderCode(
+ ctx, sb, inputString, offset, s"$length - $startOfMask")}
+ |${ev.value} = UTF8String.fromString($sb.toString());
+ |""".stripMargin
+ })
+ }
+
+ override def dataType: DataType = StringType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+
+ override def prettyName: String = "mask_show_first_n"
+}
+
+/**
+ * Masks all but the last N chars of the input string. N defaults to 4. Additional parameters can
+ * be set to change the masking chars for uppercase letters, lowercase letters and digits.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_("1234-5678-8765-4321", 4);
+ nnnn-nnnn-nnnn-4321
+ """, since = "2.4.0")
+// scalastyle:on line.size.limit
+case class MaskShowLastN(
+ child: Expression,
+ n: Int,
+ upper: String,
+ lower: String,
+ digit: String)
+ extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN {
+
+ def this(child: Expression) =
+ this(child, defaultCharCount, null, null, null)
+
+ def this(child: Expression, n: Expression) =
+ this(child, extractCharCount(n), null, null, null)
+
+ def this(child: Expression, n: Expression, upper: Expression) =
+ this(child, extractCharCount(n), extractReplacement(upper), null, null)
+
+ def this(child: Expression, n: Expression, upper: Expression, lower: Expression) =
+ this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null)
+
+ def this(
+ child: Expression,
+ n: Expression,
+ upper: Expression,
+ lower: Expression,
+ digit: Expression) =
+ this(child,
+ extractCharCount(n),
+ extractReplacement(upper),
+ extractReplacement(lower),
+ extractReplacement(digit))
+
+ override def nullSafeEval(input: Any): Any = {
+ val str = input.asInstanceOf[UTF8String].toString
+ val length = str.codePointCount(0, str.length())
+ val endOfMask = if (charCount >= length) 0 else length - charCount
+ val sb = new java.lang.StringBuilder(length)
+ val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask)
+ appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask)
+ UTF8String.fromString(sb.toString)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (input: String) => {
+ val sb = ctx.freshName("sb")
+ val length = ctx.freshName("length")
+ val offset = ctx.freshName("offset")
+ val inputString = ctx.freshName("inputString")
+ val endOfMask = ctx.freshName("endOfMask")
+ s"""
+ |String $inputString = $input.toString();
+ |${inputStringLengthCode(inputString, length)}
+ |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount;
+ |${CodeGenerator.JAVA_INT} $offset = 0;
+ |StringBuilder $sb = new StringBuilder($length);
+ |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)}
+ |${appendUnchangedToStringBuilderCode(
+ ctx, sb, inputString, offset, s"$length - $endOfMask")}
+ |${ev.value} = UTF8String.fromString($sb.toString());
+ |""".stripMargin
+ })
+ }
+
+ override def dataType: DataType = StringType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+
+ override def prettyName: String = "mask_show_last_n"
+}
+
+/**
+ * Returns a hashed value based on str.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Returns a hashed value based on str. The hash is consistent and can be used to join masked values together across tables.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_("abcd-EFGH-8765-4321");
+ 60c713f5ec6912229d2060df1c322776
+ """)
+// scalastyle:on line.size.limit
+case class MaskHash(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes {
+
+ override def nullSafeEval(input: Any): Any = {
+ UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[UTF8String].toString))
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (input: String) => {
+ val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$")
+ s"""
+ |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString()));
+ |""".stripMargin
+ })
+ }
+
+ override def dataType: DataType = StringType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+
+ override def prettyName: String = "mask_hash"
+}
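
The mask expressions introduced above all share the same per-character rule (via MaskExpressionsUtils.transformChar); they differ only in which span of the input gets masked. A plain-Scala sketch of the documented default behaviour, working on chars rather than code points as the real implementation does:

    // Defaults documented above: uppercase -> 'X', lowercase -> 'x', digits -> 'n', others unchanged.
    def maskChar(c: Char): Char =
      if (c.isUpper) 'X' else if (c.isLower) 'x' else if (c.isDigit) 'n' else c

    def maskFirstN(s: String, n: Int = 4): String = {
      val (head, tail) = s.splitAt(math.min(math.max(n, 0), s.length))
      head.map(maskChar) + tail
    }

    assert(maskFirstN("1234-5678-8765-4321") == "nnnn-5678-8765-4321")
    assert(maskFirstN("abcd-EFGH-8765-4321").startsWith("xxxx"))
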
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 3f11005a5ad1d..bdeb9ed29e0ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -168,9 +169,11 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI")
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
-// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the inverse cosine (a.k.a. arccosine) of `expr` if -1<=`expr`<=1 or NaN otherwise.",
+ usage = """
+ _FUNC_(expr) - Returns the inverse cosine (a.k.a. arc cosine) of `expr`, as if computed by
+ `java.lang.Math._FUNC_`.
+ """,
examples = """
Examples:
> SELECT _FUNC_(1);
@@ -178,12 +181,13 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI")
> SELECT _FUNC_(2);
NaN
""")
-// scalastyle:on line.size.limit
case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
-// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the inverse sine (a.k.a. arcsine) the arc sin of `expr` if -1<=`expr`<=1 or NaN otherwise.",
+ usage = """
+ _FUNC_(expr) - Returns the inverse sine (a.k.a. arc sine) of `expr`,
+ as if computed by `java.lang.Math._FUNC_`.
+ """,
examples = """
Examples:
> SELECT _FUNC_(0);
@@ -191,18 +195,18 @@ case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS"
> SELECT _FUNC_(2);
NaN
""")
-// scalastyle:on line.size.limit
case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN")
-// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the inverse tangent (a.k.a. arctangent).",
+ usage = """
+ _FUNC_(expr) - Returns the inverse tangent (a.k.a. arc tangent) of `expr`, as if computed by
+ `java.lang.Math._FUNC_`.
+ """,
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
""")
-// scalastyle:on line.size.limit
case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN")
@ExpressionDescription(
@@ -252,7 +256,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
}
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the cosine of `expr`.",
+ usage = """
+ _FUNC_(expr) - Returns the cosine of `expr`, as if computed by
+ `java.lang.Math._FUNC_`.
+ """,
+ arguments = """
+ Arguments:
+ * expr - angle in radians
+ """,
examples = """
Examples:
> SELECT _FUNC_(0);
@@ -261,7 +272,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the hyperbolic cosine of `expr`.",
+ usage = """
+ _FUNC_(expr) - Returns the hyperbolic cosine of `expr`, as if computed by
+ `java.lang.Math._FUNC_`.
+ """,
+ arguments = """
+ Arguments:
+ * expr - hyperbolic angle
+ """,
examples = """
Examples:
> SELECT _FUNC_(0);
@@ -512,7 +530,11 @@ case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND
case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the sine of `expr`.",
+ usage = "_FUNC_(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math._FUNC_`.",
+ arguments = """
+ Arguments:
+ * expr - angle in radians
+ """,
examples = """
Examples:
> SELECT _FUNC_(0);
@@ -521,7 +543,13 @@ case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "S
case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the hyperbolic sine of `expr`.",
+ usage = """
+ _FUNC_(expr) - Returns the hyperbolic sine of `expr`, as if computed by `java.lang.Math._FUNC_`.
+ """,
+ arguments = """
+ Arguments:
+ * expr - hyperbolic angle
+ """,
examples = """
Examples:
> SELECT _FUNC_(0);
@@ -539,7 +567,13 @@ case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH"
case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the tangent of `expr`.",
+ usage = """
+ _FUNC_(expr) - Returns the tangent of `expr`, as if computed by `java.lang.Math._FUNC_`.
+ """,
+ arguments = """
+ Arguments:
+ * expr - angle in radians
+ """,
examples = """
Examples:
> SELECT _FUNC_(0);
@@ -548,7 +582,13 @@ case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT"
case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the cotangent of `expr`.",
+ usage = """
+ _FUNC_(expr) - Returns the cotangent of `expr`, as if computed by `1/java.lang.Math.tan`.
+ """,
+ arguments = """
+ Arguments:
+ * expr - angle in radians
+ """,
examples = """
Examples:
> SELECT _FUNC_(1);
@@ -562,7 +602,14 @@ case class Cot(child: Expression)
}
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the hyperbolic tangent of `expr`.",
+ usage = """
+ _FUNC_(expr) - Returns the hyperbolic tangent of `expr`, as if computed by
+ `java.lang.Math._FUNC_`.
+ """,
+ arguments = """
+ Arguments:
+ * expr - hyperbolic angle
+ """,
examples = """
Examples:
> SELECT _FUNC_(0);
@@ -572,6 +619,10 @@ case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH"
@ExpressionDescription(
usage = "_FUNC_(expr) - Converts radians to degrees.",
+ arguments = """
+ Arguments:
+ * expr - angle in radians
+ """,
examples = """
Examples:
> SELECT _FUNC_(3.141592653589793);
@@ -583,6 +634,10 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre
@ExpressionDescription(
usage = "_FUNC_(expr) - Converts degrees to radians.",
+ arguments = """
+ Arguments:
+ * expr - angle in degrees
+ """,
examples = """
Examples:
> SELECT _FUNC_(180);
@@ -768,15 +823,22 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
-// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(expr1, expr2) - Returns the angle in radians between the positive x-axis of a plane and the point given by the coordinates (`expr1`, `expr2`).",
+ usage = """
+ _FUNC_(exprY, exprX) - Returns the angle in radians between the positive x-axis of a plane
+ and the point given by the coordinates (`exprX`, `exprY`), as if computed by
+ `java.lang.Math._FUNC_`.
+ """,
+ arguments = """
+ Arguments:
+ * exprY - coordinate on y-axis
+ * exprX - coordinate on x-axis
+ """,
examples = """
Examples:
> SELECT _FUNC_(0, 0);
0.0
""")
-// scalastyle:on line.size.limit
case class Atan2(left: Expression, right: Expression)
extends BinaryMathExpression(math.atan2, "ATAN2") {
@@ -1130,15 +1192,16 @@ abstract class RoundBase(child: Expression, scale: Expression,
}"""
}
+ val javaType = CodeGenerator.javaType(dataType)
if (scaleV == null) { // if scale is null, no need to eval its child at all
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${ce.code}
boolean ${ev.isNull} = ${ce.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
$evaluationCode
}""")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 4b9006ab5b423..5d98dac46cf17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -21,6 +21,8 @@ import java.util.UUID
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -31,7 +33,12 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
override def dataType: DataType = child.dataType
- protected override def nullSafeEval(input: Any): Any = input
+ protected override def nullSafeEval(input: Any): Any = {
+ // scalastyle:off println
+ System.err.println(outputPrefix + input)
+ // scalastyle:on println
+ input
+ }
private val outputPrefix = s"Result of ${child.simpleString} is "
@@ -82,10 +89,11 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the value is null or false.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
- ExprCode(code = s"""${eval.code}
+ ExprCode(code = code"""${eval.code}
|if (${eval.isNull} || !${eval.value}) {
| throw new RuntimeException($errMsgField);
- |}""".stripMargin, isNull = "true", value = "null")
+ |}""".stripMargin, isNull = TrueLiteral,
+ value = JavaCode.defaultLiteral(dataType))
}
override def sql: String = s"assert_true(${child.sql})"
@@ -110,25 +118,43 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable {
// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.",
+ usage = """_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.""",
examples = """
Examples:
> SELECT _FUNC_();
46707d92-02f4-4817-8116-a4c3b23e6266
- """)
+ """,
+ note = "The function is non-deterministic.")
// scalastyle:on line.size.limit
-case class Uuid() extends LeafExpression {
+case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful {
+
+ def this() = this(None)
- override lazy val deterministic: Boolean = false
+ override lazy val resolved: Boolean = randomSeed.isDefined
override def nullable: Boolean = false
override def dataType: DataType = StringType
- override def eval(input: InternalRow): Any = UTF8String.fromString(UUID.randomUUID().toString)
+ @transient private[this] var randomGenerator: RandomUUIDGenerator = _
+
+ override protected def initializeInternal(partitionIndex: Int): Unit =
+ randomGenerator = RandomUUIDGenerator(randomSeed.get + partitionIndex)
+
+ override protected def evalInternal(input: InternalRow): Any =
+ randomGenerator.getNextUUIDUTF8String()
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- ev.copy(code = s"final UTF8String ${ev.value} = " +
- s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false")
+ val randomGen = ctx.freshName("randomGen")
+ ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen,
+ forceInline = true,
+ useFreshName = false)
+ ctx.addPartitionInitializationStatement(s"$randomGen = " +
+ "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" +
+ s"${randomSeed.get}L + partitionIndex);")
+ ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
+ isNull = FalseLiteral)
}
+
+ override def freshCopy(): Uuid = Uuid(randomSeed)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 470d5da041ea5..2eeed3bbb2d91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -72,7 +73,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+ ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
// all the evals are meant to be in a do { ... } while (false); loop
val evals = children.map { e =>
@@ -87,14 +88,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
""".stripMargin
}
- val resultType = ctx.javaType(dataType)
+ val resultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "coalesce",
returnType = resultType,
makeSplitFunction = func =>
s"""
- |$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
+ |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|do {
| $func
|} while (false);
@@ -111,9 +112,9 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
ev.copy(code =
- s"""
+ code"""
|${ev.isNull} = true;
- |$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
+ |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|do {
| $codes
|} while (false);
@@ -232,10 +233,10 @@ case class IsNaN(child: Expression) extends UnaryExpression
val eval = child.genCode(ctx)
child.dataType match {
case DoubleType | FloatType =>
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false")
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+ ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral)
}
}
}
@@ -278,10 +279,10 @@ case class NaNvl(left: Expression, right: Expression)
val rightGen = right.genCode(ctx)
left.dataType match {
case DoubleType | FloatType =>
- ev.copy(code = s"""
+ ev.copy(code = code"""
${leftGen.code}
boolean ${ev.isNull} = false;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${leftGen.isNull}) {
${ev.isNull} = true;
} else {
@@ -320,7 +321,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
- ExprCode(code = eval.code, isNull = "false", value = eval.isNull)
+ ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.isNull)
}
override def sql: String = s"(${child.sql} IS NULL)"
@@ -346,7 +347,12 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
- ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))")
+ val value = eval.isNull match {
+ case TrueLiteral => FalseLiteral
+ case FalseLiteral => TrueLiteral
+ case v => JavaCode.isNullExpression(s"!$v")
+ }
+ ExprCode(code = eval.code, isNull = FalseLiteral, value = value)
}
override def sql: String = s"(${child.sql} IS NOT NULL)"
@@ -416,8 +422,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "atLeastNNonNulls",
- extraArguments = (ctx.JAVA_INT, nonnull) :: Nil,
- returnType = ctx.JAVA_INT,
+ extraArguments = (CodeGenerator.JAVA_INT, nonnull) :: Nil,
+ returnType = CodeGenerator.JAVA_INT,
makeSplitFunction = body =>
s"""
|do {
@@ -435,12 +441,12 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
}.mkString)
ev.copy(code =
- s"""
- |${ctx.JAVA_INT} $nonnull = 0;
+ code"""
+ |${CodeGenerator.JAVA_INT} $nonnull = 0;
|do {
| $codes
|} while (false);
- |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
- """.stripMargin, isNull = "false")
+ |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
+ """.stripMargin, isNull = FalseLiteral)
}
}
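
The nullExpressions changes above swap raw strings such as "false" for typed values (FalseLiteral, JavaCode.isNullGlobal) and the code"..." interpolator. A simplified sketch of why a typed representation helps; the names below are illustrative, not the actual Spark classes:

    // Illustrative only: a typed value knows whether it is a literal or a variable, so the
    // generator can constant-fold, e.g. negating isNull as IsNotNull.doGenCode does above.
    sealed trait JavaValueSketch { def code: String }
    case class LiteralValue(code: String) extends JavaValueSketch
    case class VariableValue(code: String) extends JavaValueSketch

    val TrueLit  = LiteralValue("true")
    val FalseLit = LiteralValue("false")

    def negate(isNull: JavaValueSketch): JavaValueSketch = isNull match {
      case TrueLit  => FalseLit
      case FalseLit => TrueLit
      case v        => VariableValue(s"!(${v.code})")
    }
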
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 64da9bb9cdec1..2bf4203d0fec3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -17,8 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.objects
-import java.lang.reflect.Modifier
+import java.lang.reflect.{Method, Modifier}
+import scala.collection.JavaConverters._
import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag
@@ -27,13 +28,16 @@ import scala.util.Try
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.util.Utils
/**
* Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]].
@@ -59,16 +63,16 @@ trait InvokeLike extends Expression with NonSQLExpression {
* @param ctx a [[CodegenContext]]
* @return (code to prepare arguments, argument string, result of argument null check)
*/
- def prepareArguments(ctx: CodegenContext): (String, String, String) = {
+ def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = {
val resultIsNull = if (needNullCheck) {
- val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull")
- resultIsNull
+ val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull")
+ JavaCode.isNullGlobal(resultIsNull)
} else {
- "false"
+ FalseLiteral
}
val argValues = arguments.map { e =>
- val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue")
+ val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue")
argValue
}
@@ -103,6 +107,93 @@ trait InvokeLike extends Expression with NonSQLExpression {
(argCode, argValues.mkString(", "), resultIsNull)
}
+
+ /**
+ * Evaluate each argument with a given row, invoke a method with a given object and arguments,
+ * and cast a return value if the return type can be mapped to a Java Boxed type
+ *
+ * @param obj the object for the method to be called. If null, a static method call is performed
+ * @param method the method object to be called
+ * @param arguments the arguments used for the method call
+ * @param input the row used for evaluating arguments
+ * @param dataType the data type of the return object
+ * @return the return object of a method call
+ */
+ def invoke(
+ obj: Any,
+ method: Method,
+ arguments: Seq[Expression],
+ input: InternalRow,
+ dataType: DataType): Any = {
+ val args = arguments.map(e => e.eval(input).asInstanceOf[Object])
+ if (needNullCheck && args.exists(_ == null)) {
+ // return null if one of arguments is null
+ null
+ } else {
+ val ret = method.invoke(obj, args: _*)
+ val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType)
+ if (boxedClass.isDefined) {
+ boxedClass.get.cast(ret)
+ } else {
+ ret
+ }
+ }
+ }
+}
+
+/**
+ * Common trait for [[DecodeUsingSerializer]] and [[EncodeUsingSerializer]]
+ */
+trait SerializerSupport {
+ /**
+ * If true, Kryo serialization is used, otherwise the Java one is used
+ */
+ val kryo: Boolean
+
+ /**
+ * The serializer instance to be used for serialization/deserialization in interpreted execution
+ */
+ lazy val serializerInstance: SerializerInstance = SerializerSupport.newSerializer(kryo)
+
+ /**
+ * Adds an immutable state to the generated class containing a reference to the serializer.
+ * @return a string containing the name of the variable referencing the serializer
+ */
+ def addImmutableSerializerIfNeeded(ctx: CodegenContext): String = {
+ val (serializerInstance, serializerInstanceClass) = {
+ if (kryo) {
+ ("kryoSerializer",
+ classOf[KryoSerializerInstance].getName)
+ } else {
+ ("javaSerializer",
+ classOf[JavaSerializerInstance].getName)
+ }
+ }
+ val newSerializerMethod = s"${classOf[SerializerSupport].getName}$$.MODULE$$.newSerializer"
+ // Code to initialize the serializer
+ ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializerInstance, v =>
+ s"""
+ |$v = ($serializerInstanceClass) $newSerializerMethod($kryo);
+ """.stripMargin)
+ serializerInstance
+ }
+}
+
+object SerializerSupport {
+ /**
+ * It creates a new `SerializerInstance` which is either a `KryoSerializerInstance` (if
+ * `useKryo` is set to `true`) or a `JavaSerializerInstance`.
+ */
+ def newSerializer(useKryo: Boolean): SerializerInstance = {
+ // try conf from env, otherwise create a new one
+ val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
+ val s = if (useKryo) {
+ new KryoSerializer(conf)
+ } else {
+ new JavaSerializer(conf)
+ }
+ s.newInstance()
+ }
}
/**
@@ -129,15 +220,24 @@ case class StaticInvoke(
returnNullable: Boolean = true) extends InvokeLike {
val objectName = staticObject.getName.stripSuffix("$")
+ val cls = if (staticObject.getName == objectName) {
+ staticObject
+ } else {
+ Utils.classForName(objectName)
+ }
override def nullable: Boolean = needNullCheck || returnNullable
override def children: Seq[Expression] = arguments
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+ lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
+ @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*)
+
+ override def eval(input: InternalRow): Any = {
+ invoke(null, method, arguments, input, dataType)
+ }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
@@ -146,12 +246,12 @@ case class StaticInvoke(
val prepareIsNull = if (nullable) {
s"boolean ${ev.isNull} = $resultIsNull;"
} else {
- ev.isNull = "false"
+ ev.isNull = FalseLiteral
""
}
val evaluate = if (returnNullable) {
- if (ctx.defaultValue(dataType) == "null") {
+ if (CodeGenerator.defaultValue(dataType) == "null") {
s"""
${ev.value} = $callFunc;
${ev.isNull} = ${ev.value} == null;
@@ -159,7 +259,7 @@ case class StaticInvoke(
} else {
val boxedResult = ctx.freshName("boxedResult")
s"""
- ${ctx.boxedType(dataType)} $boxedResult = $callFunc;
+ ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc;
${ev.isNull} = $boxedResult == null;
if (!${ev.isNull}) {
${ev.value} = $boxedResult;
@@ -170,10 +270,10 @@ case class StaticInvoke(
s"${ev.value} = $callFunc;"
}
- val code = s"""
+ val code = code"""
$argCode
$prepareIsNull
- $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!$resultIsNull) {
$evaluate
}
@@ -208,12 +308,11 @@ case class Invoke(
propagateNull: Boolean = true,
returnNullable : Boolean = true) extends InvokeLike {
+ lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
+
override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
override def children: Seq[Expression] = targetObject +: arguments
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
-
private lazy val encodedFunctionName = TermName(functionName).encodedName.toString
@transient lazy val method = targetObject.dataType match {
@@ -227,8 +326,23 @@ case class Invoke(
case _ => None
}
+ override def eval(input: InternalRow): Any = {
+ val obj = targetObject.eval(input)
+ if (obj == null) {
+ // return null if obj is null
+ null
+ } else {
+ val invokeMethod = if (method.isDefined) {
+ method.get
+ } else {
+ obj.getClass.getDeclaredMethod(functionName, argClasses: _*)
+ }
+ invoke(obj, invokeMethod, arguments, input, dataType)
+ }
+ }
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val obj = targetObject.genCode(ctx)
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
@@ -255,11 +369,11 @@ case class Invoke(
// If the function can return null, we do an extra check to make sure our null bit is still
// set correctly.
val assignResult = if (!returnNullable) {
- s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;"
+ s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;"
} else {
s"""
if ($funcResult != null) {
- ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
+ ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;
} else {
${ev.isNull} = true;
}
@@ -272,10 +386,9 @@ case class Invoke(
"""
}
- val code = s"""
- ${obj.code}
+ val code = obj.code + code"""
boolean ${ev.isNull} = true;
- $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${obj.isNull}) {
$argCode
${ev.isNull} = $resultIsNull;
@@ -337,11 +450,35 @@ case class NewInstance(
childrenResolved && !needOuterPointer
}
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+ @transient private lazy val constructor: (Seq[AnyRef]) => Any = {
+ val paramTypes = ScalaReflection.expressionJavaClasses(arguments)
+ val getConstructor = (paramClazz: Seq[Class[_]]) => {
+ ScalaReflection.findConstructor(cls, paramClazz).getOrElse {
+ sys.error(s"Couldn't find a valid constructor on $cls")
+ }
+ }
+ outerPointer.map { p =>
+ val outerObj = p()
+ val d = outerObj.getClass +: paramTypes
+ val c = getConstructor(outerObj.getClass +: paramTypes)
+ (args: Seq[AnyRef]) => {
+ c.newInstance(outerObj +: args: _*)
+ }
+ }.getOrElse {
+ val c = getConstructor(paramTypes)
+ (args: Seq[AnyRef]) => {
+ c.newInstance(args: _*)
+ }
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val argValues = arguments.map(_.eval(input))
+ constructor(argValues.map(_.asInstanceOf[AnyRef]))
+ }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
@@ -355,10 +492,11 @@ case class NewInstance(
s"new $className($argString)"
}
- val code = s"""
+ val code = code"""
$argCode
${outer.map(_.code).getOrElse("")}
- final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
+ final $javaType ${ev.value} = ${ev.isNull} ?
+ ${CodeGenerator.defaultValue(dataType)} : $constructorCall;
"""
ev.copy(code = code)
}
@@ -381,19 +519,23 @@ case class UnwrapOption(
override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ override def eval(input: InternalRow): Any = {
+ val inputObject = child.eval(input)
+ if (inputObject == null) {
+ null
+ } else {
+ inputObject.asInstanceOf[Option[_]].orNull
+ }
+ }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val inputObject = child.genCode(ctx)
- val code = s"""
- ${inputObject.code}
-
+ val code = inputObject.code + code"""
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
- $javaType ${ev.value} = ${ev.isNull} ?
- ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get();
+ $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} :
+ (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get();
"""
ev.copy(code = code)
}
@@ -415,20 +557,17 @@ case class WrapOption(child: Expression, optType: DataType)
override def inputTypes: Seq[AbstractDataType] = optType :: Nil
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ override def eval(input: InternalRow): Any = Option(child.eval(input))
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val inputObject = child.genCode(ctx)
- val code = s"""
- ${inputObject.code}
-
+ val code = inputObject.code + code"""
scala.Option ${ev.value} =
${inputObject.isNull} ?
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
"""
- ev.copy(code = code, isNull = "false")
+ ev.copy(code = code, isNull = FalseLiteral)
}
}
@@ -440,12 +579,33 @@ case class LambdaVariable(
value: String,
isNull: String,
dataType: DataType,
- nullable: Boolean = true) extends LeafExpression
- with Unevaluable with NonSQLExpression {
+ nullable: Boolean = true) extends LeafExpression with NonSQLExpression {
+
+ private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
+
+ // Interpreted execution of `LambdaVariable` always gets the 0-index element from the input row.
+ override def eval(input: InternalRow): Any = {
+ assert(input.numFields == 1,
+ "The input row of interpreted LambdaVariable should have only 1 field.")
+ if (nullable && input.isNullAt(0)) {
+ null
+ } else {
+ accessor(input, 0)
+ }
+ }
override def genCode(ctx: CodegenContext): ExprCode = {
- ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
+ val isNullValue = if (nullable) {
+ JavaCode.isNullVariable(isNull)
+ } else {
+ FalseLiteral
+ }
+ ExprCode(value = JavaCode.variable(value, dataType), isNull = isNullValue)
}
+
+ // This won't be called as `genCode` is overridden; we only override it to make
+ // `LambdaVariable` non-abstract.
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
}
/**
@@ -538,15 +698,99 @@ case class MapObjects private(
override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ // Data with a UserDefinedType is actually stored using the data type of its sqlType.
+ // When we want to apply MapObjects to it, we have to use that sqlType.
+ lazy private val inputDataType = inputData.dataType match {
+ case u: UserDefinedType[_] => u.sqlType
+ case _ => inputData.dataType
+ }
+
+ private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
+ val row = new GenericInternalRow(1)
+ inputCollection.toIterator.map { element =>
+ row.update(0, element)
+ lambdaFunction.eval(row)
+ }
+ }
+
+ private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
+ case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+ _.asInstanceOf[Seq[_]]
+ case ObjectType(cls) if cls.isArray =>
+ _.asInstanceOf[Array[_]].toSeq
+ case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+ _.asInstanceOf[java.util.List[_]].asScala
+ case ObjectType(cls) if cls == classOf[Object] =>
+ (inputCollection) => {
+ if (inputCollection.getClass.isArray) {
+ inputCollection.asInstanceOf[Array[_]].toSeq
+ } else {
+ inputCollection.asInstanceOf[Seq[_]]
+ }
+ }
+ case ArrayType(et, _) =>
+ _.asInstanceOf[ArrayData].toSeq[Any](et)
+ }
+
+ private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
+ case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+ // Scala sequence
+ executeFuncOnCollection(_).toSeq
+ case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
+ // Scala set
+ executeFuncOnCollection(_).toSet
+ case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+ // Java list
+ if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
+ cls == classOf[java.util.AbstractSequentialList[_]]) {
+ // Specifying non-concrete implementations of `java.util.List`
+ executeFuncOnCollection(_).toSeq.asJava
+ } else {
+ val constructors = cls.getConstructors()
+ val intParamConstructor = constructors.find { constructor =>
+ constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
+ }
+ val noParamConstructor = constructors.find { constructor =>
+ constructor.getParameterCount == 0
+ }
+
+ val constructor = intParamConstructor.map { intConstructor =>
+ (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
+ }.getOrElse {
+ (_: Int) => noParamConstructor.get.newInstance()
+ }
+
+ // Specifying concrete implementations of `java.util.List`
+ (inputs) => {
+ val results = executeFuncOnCollection(inputs)
+ val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]]
+ results.foreach(builder.add(_))
+ builder
+ }
+ }
+ case None =>
+ // array
+ x => new GenericArrayData(executeFuncOnCollection(x).toArray)
+ case Some(cls) =>
+ throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " +
+ "resulting collection.")
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val inputCollection = inputData.eval(input)
+
+ if (inputCollection == null) {
+ return null
+ }
+ mapElements(convertToSeq(inputCollection))
+ }
override def dataType: DataType =
customCollectionCls.map(ObjectType.apply).getOrElse(
ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val elementJavaType = ctx.javaType(loopVarDataType)
+ val elementJavaType = CodeGenerator.javaType(loopVarDataType)
ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false)
val genInputData = inputData.genCode(ctx)
val genFunction = lambdaFunction.genCode(ctx)
@@ -554,7 +798,7 @@ case class MapObjects private(
val convertedArray = ctx.freshName("convertedArray")
val loopIndex = ctx.freshName("loopIndex")
- val convertedType = ctx.boxedType(lambdaFunction.dataType)
+ val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType)
// Because of the way Java defines nested arrays, we have to handle the syntax specially.
// Specifically, we have to insert the [$dataLength] in between the type and any extra nested
@@ -586,13 +830,6 @@ case class MapObjects private(
case _ => ""
}
- // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
- // When we want to apply MapObjects on it, we have to use it.
- val inputDataType = inputData.dataType match {
- case p: PythonUserDefinedType => p.sqlType
- case _ => inputData.dataType
- }
-
// `MapObjects` generates a while loop to traverse the elements of the input collection. We
// need to take care of Seq and List because they may have O(n) complexity for indexed accessing
// like `list.get(1)`. Here we use Iterator to traverse Seq and List.
@@ -621,7 +858,7 @@ case class MapObjects private(
(
s"${genInputData.value}.numElements()",
"",
- ctx.getValue(genInputData.value, et, loopIndex)
+ CodeGenerator.getValue(genInputData.value, et, loopIndex)
)
case ObjectType(cls) if cls == classOf[Object] =>
val it = ctx.freshName("it")
@@ -635,7 +872,7 @@ case class MapObjects private(
// Make a copy of the data if it's unsafe-backed
def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value"
- val genFunctionValue = lambdaFunction.dataType match {
+ val genFunctionValue: String = lambdaFunction.dataType match {
case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
@@ -643,7 +880,8 @@ case class MapObjects private(
}
val loopNullCheck = if (loopIsNull != "false") {
- ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false)
+ ctx.addMutableState(
+ CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false)
inputDataType match {
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
case _ => s"$loopIsNull = $loopValue == null;"
@@ -693,9 +931,8 @@ case class MapObjects private(
)
}
- val code = s"""
- ${genInputData.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ val code = genInputData.code + code"""
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
$determineCollectionType
@@ -792,8 +1029,41 @@ case class CatalystToExternalMap private(
override def children: Seq[Expression] =
keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType]
+
+ private lazy val keyConverter =
+ CatalystTypeConverters.createToScalaConverter(inputMapType.keyType)
+ private lazy val valueConverter =
+ CatalystTypeConverters.createToScalaConverter(inputMapType.valueType)
+
+ private lazy val (newMapBuilderMethod, moduleField) = {
+ val clazz = Utils.classForName(collClass.getCanonicalName + "$")
+ (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null))
+ }
+
+ private def newMapBuilder(): Builder[AnyRef, AnyRef] = {
+ newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]]
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val result = inputData.eval(input).asInstanceOf[MapData]
+ if (result != null) {
+ val builder = newMapBuilder()
+ builder.sizeHint(result.numElements())
+ val keyArray = result.keyArray()
+ val valueArray = result.valueArray()
+ var i = 0
+ while (i < result.numElements()) {
+ val key = keyConverter(keyArray.get(i, inputMapType.keyType))
+ val value = valueConverter(valueArray.get(i, inputMapType.valueType))
+ builder += Tuple2(key, value)
+ i += 1
+ }
+ builder.result()
+ } else {
+ null
+ }
+ }
override def dataType: DataType = ObjectType(collClass)
@@ -806,10 +1076,10 @@ case class CatalystToExternalMap private(
}
val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
- val keyElementJavaType = ctx.javaType(mapType.keyType)
+ val keyElementJavaType = CodeGenerator.javaType(mapType.keyType)
ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false)
val genKeyFunction = keyLambdaFunction.genCode(ctx)
- val valueElementJavaType = ctx.javaType(mapType.valueType)
+ val valueElementJavaType = CodeGenerator.javaType(mapType.valueType)
ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true,
useFreshName = false)
val genValueFunction = valueLambdaFunction.genCode(ctx)
@@ -825,10 +1095,11 @@ case class CatalystToExternalMap private(
val valueArray = ctx.freshName("valueArray")
val getKeyArray =
s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
- val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
+ val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
val getValueArray =
s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
- val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex)
+ val getValueLoopVar = CodeGenerator.getValue(
+ valueArray, inputDataType(mapType.valueType), loopIndex)
// Make a copy of the data if it's unsafe-backed
def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
@@ -844,7 +1115,7 @@ case class CatalystToExternalMap private(
val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)
val valueLoopNullCheck = if (valueLoopIsNull != "false") {
- ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true,
+ ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true,
useFreshName = false)
s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
} else {
@@ -871,9 +1142,8 @@ case class CatalystToExternalMap private(
"""
val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"
- val code = s"""
- ${genInputData.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ val code = genInputData.code + code"""
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
int $dataLength = $getLength;
@@ -979,8 +1249,72 @@ case class ExternalMapToCatalyst private(
override def dataType: MapType = MapType(
keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable)
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = {
+ val rowBuffer = InternalRow.fromSeq(Array[Any](1))
+ def rowWrapper(data: Any): InternalRow = {
+ rowBuffer.update(0, data)
+ rowBuffer
+ }
+
+ child.dataType match {
+ case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
+ (input: Any) => {
+ val data = input.asInstanceOf[java.util.Map[Any, Any]]
+ val keys = new Array[Any](data.size)
+ val values = new Array[Any](data.size)
+ val iter = data.entrySet().iterator()
+ var i = 0
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val (key, value) = (entry.getKey, entry.getValue)
+ keys(i) = if (key != null) {
+ keyConverter.eval(rowWrapper(key))
+ } else {
+ throw new RuntimeException("Cannot use null as map key!")
+ }
+ values(i) = if (value != null) {
+ valueConverter.eval(rowWrapper(value))
+ } else {
+ null
+ }
+ i += 1
+ }
+ (keys, values)
+ }
+
+ case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) =>
+ (input: Any) => {
+ val data = input.asInstanceOf[scala.collection.Map[Any, Any]]
+ val keys = new Array[Any](data.size)
+ val values = new Array[Any](data.size)
+ var i = 0
+ for ((key, value) <- data) {
+ keys(i) = if (key != null) {
+ keyConverter.eval(rowWrapper(key))
+ } else {
+ throw new RuntimeException("Cannot use null as map key!")
+ }
+ values(i) = if (value != null) {
+ valueConverter.eval(rowWrapper(value))
+ } else {
+ null
+ }
+ i += 1
+ }
+ (keys, values)
+ }
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val result = child.eval(input)
+ if (result != null) {
+ val (keys, values) = mapCatalystConverter(result)
+ new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
+ } else {
+ null
+ }
+ }
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val inputMap = child.genCode(ctx)
@@ -993,8 +1327,8 @@ case class ExternalMapToCatalyst private(
val entry = ctx.freshName("entry")
val entries = ctx.freshName("entries")
- val keyElementJavaType = ctx.javaType(keyType)
- val valueElementJavaType = ctx.javaType(valueType)
+ val keyElementJavaType = CodeGenerator.javaType(keyType)
+ val valueElementJavaType = CodeGenerator.javaType(valueType)
ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false)
ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false)
@@ -1009,8 +1343,8 @@ case class ExternalMapToCatalyst private(
val defineKeyValue =
s"""
final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
- $key = (${ctx.boxedType(keyType)}) $entry.getKey();
- $value = (${ctx.boxedType(valueType)}) $entry.getValue();
+ $key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey();
+ $value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue();
"""
defineEntries -> defineKeyValue
@@ -1024,22 +1358,24 @@ case class ExternalMapToCatalyst private(
val defineKeyValue =
s"""
final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
- $key = (${ctx.boxedType(keyType)}) $entry._1();
- $value = (${ctx.boxedType(valueType)}) $entry._2();
+ $key = (${CodeGenerator.boxedType(keyType)}) $entry._1();
+ $value = (${CodeGenerator.boxedType(valueType)}) $entry._2();
"""
defineEntries -> defineKeyValue
}
val keyNullCheck = if (keyIsNull != "false") {
- ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false)
+ ctx.addMutableState(
+ CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false)
s"$keyIsNull = $key == null;"
} else {
""
}
val valueNullCheck = if (valueIsNull != "false") {
- ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false)
+ ctx.addMutableState(
+ CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false)
s"$valueIsNull = $value == null;"
} else {
""
@@ -1047,12 +1383,11 @@ case class ExternalMapToCatalyst private(
val arrayCls = classOf[GenericArrayData].getName
val mapCls = classOf[ArrayBasedMapData].getName
- val convertedKeyType = ctx.boxedType(keyConverter.dataType)
- val convertedValueType = ctx.boxedType(valueConverter.dataType)
- val code =
- s"""
- ${inputMap.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType)
+ val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType)
+ val code = inputMap.code +
+ code"""
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${inputMap.isNull}) {
final int $length = ${inputMap.value}.size();
final Object[] $convertedKeys = new Object[$length];
@@ -1101,8 +1436,10 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
override def nullable: Boolean = false
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ override def eval(input: InternalRow): Any = {
+ val values = children.map(_.eval(input)).toArray
+ new GenericRowWithSchema(values, schema)
+ }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rowClass = classOf[GenericRowWithSchema].getName
@@ -1127,12 +1464,12 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
val schemaField = ctx.addReferenceObj("schema", schema)
val code =
- s"""
+ code"""
|Object[] $values = new Object[${children.size}];
|$childrenCode
|final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
""".stripMargin
- ev.copy(code = code, isNull = "false")
+ ev.copy(code = code, isNull = FalseLiteral)
}
}
@@ -1142,44 +1479,22 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
* @param kryo if true, use Kryo. Otherwise, use Java.
*/
case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
- extends UnaryExpression with NonSQLExpression {
+ extends UnaryExpression with NonSQLExpression with SerializerSupport {
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ override def nullSafeEval(input: Any): Any = {
+ serializerInstance.serialize(input).array()
+ }
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- // Code to initialize the serializer.
- val (serializer, serializerClass, serializerInstanceClass) = {
- if (kryo) {
- ("kryoSerializer",
- classOf[KryoSerializer].getName,
- classOf[KryoSerializerInstance].getName)
- } else {
- ("javaSerializer",
- classOf[JavaSerializer].getName,
- classOf[JavaSerializerInstance].getName)
- }
- }
- // try conf from env, otherwise create a new one
- val env = s"${classOf[SparkEnv].getName}.get()"
- val sparkConf = s"new ${classOf[SparkConf].getName}()"
- ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
- s"""
- |if ($env == null) {
- | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
- |} else {
- | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
- |}
- """.stripMargin)
-
+ val serializer = addImmutableSerializerIfNeeded(ctx)
// Code to serialize.
val input = child.genCode(ctx)
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val serialize = s"$serializer.serialize(${input.value}, null).array()"
- val code = s"""
- ${input.code}
- final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize;
+ val code = input.code + code"""
+ final $javaType ${ev.value} =
+ ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;
"""
ev.copy(code = code, isNull = input.isNull)
}
@@ -1194,42 +1509,24 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
* @param kryo if true, use Kryo. Otherwise, use Java.
*/
case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean)
- extends UnaryExpression with NonSQLExpression {
+ extends UnaryExpression with NonSQLExpression with SerializerSupport {
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- // Code to initialize the serializer.
- val (serializer, serializerClass, serializerInstanceClass) = {
- if (kryo) {
- ("kryoSerializer",
- classOf[KryoSerializer].getName,
- classOf[KryoSerializerInstance].getName)
- } else {
- ("javaSerializer",
- classOf[JavaSerializer].getName,
- classOf[JavaSerializerInstance].getName)
- }
- }
- // try conf from env, otherwise create a new one
- val env = s"${classOf[SparkEnv].getName}.get()"
- val sparkConf = s"new ${classOf[SparkConf].getName}()"
- ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
- s"""
- |if ($env == null) {
- | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
- |} else {
- | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance();
- |}
- """.stripMargin)
+ override def nullSafeEval(input: Any): Any = {
+ val inputBytes = java.nio.ByteBuffer.wrap(input.asInstanceOf[Array[Byte]])
+ serializerInstance.deserialize(inputBytes)
+ }
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val serializer = addImmutableSerializerIfNeeded(ctx)
// Code to deserialize.
val input = child.genCode(ctx)
- val javaType = ctx.javaType(dataType)
+ val javaType = CodeGenerator.javaType(dataType)
val deserialize =
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
- val code = s"""
- ${input.code}
- final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize;
+ val code = input.code + code"""
+ final $javaType ${ev.value} =
+ ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;
"""
ev.copy(code = code, isNull = input.isNull)
}
@@ -1247,21 +1544,60 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
override def children: Seq[Expression] = beanInstance +: setters.values.toSeq
override def dataType: DataType = beanInstance.dataType
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+ private lazy val resolvedSetters = {
+ assert(beanInstance.dataType.isInstanceOf[ObjectType])
+
+ val ObjectType(beanClass) = beanInstance.dataType
+ setters.map {
+ case (name, expr) =>
+ // Look for a setter that takes the statically known parameter type first,
+ // but also try a general `Object`-typed parameter for generic methods.
+ val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object])
+ val methods = paramTypes.flatMap { fieldClass =>
+ try {
+ Some(beanClass.getDeclaredMethod(name, fieldClass))
+ } catch {
+ case e: NoSuchMethodException => None
+ }
+ }
+ if (methods.isEmpty) {
+ throw new NoSuchMethodException(s"""A method named "$name" is not declared """ +
+ "in any enclosing class nor any supertype")
+ }
+ methods.head -> expr
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val instance = beanInstance.eval(input)
+ if (instance != null) {
+ val bean = instance.asInstanceOf[Object]
+ resolvedSetters.foreach {
+ case (setter, expr) =>
+ val paramVal = expr.eval(input)
+ // We don't call the setter if the input value is null.
+ if (paramVal != null) {
+ setter.invoke(bean, paramVal.asInstanceOf[AnyRef])
+ }
+ }
+ }
+ instance
+ }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val instanceGen = beanInstance.genCode(ctx)
val javaBeanInstance = ctx.freshName("javaBean")
- val beanInstanceJavaType = ctx.javaType(beanInstance.dataType)
+ val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType)
val initialize = setters.map {
case (setterMethod, fieldValue) =>
val fieldGen = fieldValue.genCode(ctx)
s"""
|${fieldGen.code}
- |$javaBeanInstance.$setterMethod(${fieldGen.value});
+ |if (!${fieldGen.isNull}) {
+ | $javaBeanInstance.$setterMethod(${fieldGen.value});
+ |}
""".stripMargin
}
val initializeCode = ctx.splitExpressionsWithCurrentInputs(
@@ -1269,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
funcName = "initializeJavaBean",
extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil)
- val code =
- s"""
- |${instanceGen.code}
+ val code = instanceGen.code +
+ code"""
|$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value};
|if (!${instanceGen.isNull}) {
| $initializeCode
@@ -1319,14 +1654,12 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
// because errMsgField is used only when the value is null.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
- val code = s"""
- ${childGen.code}
-
+ val code = childGen.code + code"""
if (${childGen.isNull}) {
throw new NullPointerException($errMsgField);
}
"""
- ev.copy(code = code, isNull = "false", value = childGen.value)
+ ev.copy(code = code, isNull = FalseLiteral, value = childGen.value)
}
}
@@ -1346,17 +1679,25 @@ case class GetExternalRowField(
override def dataType: DataType = ObjectType(classOf[Object])
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
-
private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null."
+ override def eval(input: InternalRow): Any = {
+ val inputRow = child.eval(input).asInstanceOf[Row]
+ if (inputRow == null) {
+ throw new RuntimeException("The input external row cannot be null.")
+ }
+ if (inputRow.isNullAt(index)) {
+ throw new RuntimeException(errMsg)
+ }
+ inputRow.get(index)
+ }
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the field is null.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val row = child.genCode(ctx)
- val code = s"""
+ val code = code"""
${row.code}
if (${row.isNull}) {
@@ -1369,7 +1710,7 @@ case class GetExternalRowField(
final Object ${ev.value} = ${row.value}.get($index);
"""
- ev.copy(code = code, isNull = "false")
+ ev.copy(code = code, isNull = FalseLiteral)
}
}
@@ -1384,13 +1725,36 @@ case class ValidateExternalType(child: Expression, expected: DataType)
override def nullable: Boolean = child.nullable
- override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
-
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"
+ private lazy val checkType: (Any) => Boolean = expected match {
+ case _: DecimalType =>
+ (value: Any) => {
+ value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] ||
+ value.isInstanceOf[Decimal]
+ }
+ case _: ArrayType =>
+ (value: Any) => {
+ value.getClass.isArray || value.isInstanceOf[Seq[_]]
+ }
+ case _ =>
+ val dataTypeClazz = ScalaReflection.javaBoxedType(dataType)
+ (value: Any) => {
+ dataTypeClazz.isInstance(value)
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val result = child.eval(input)
+ if (checkType(result)) {
+ result
+ } else {
+ throw new RuntimeException(s"${result.getClass.getName}$errMsg")
+ }
+ }
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the type doesn't match.
@@ -1403,17 +1767,17 @@ case class ValidateExternalType(child: Expression, expected: DataType)
Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
.map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
case _: ArrayType =>
- s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
+ s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}"
case _ =>
- s"$obj instanceof ${ctx.boxedType(dataType)}"
+ s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
}
- val code = s"""
+ val code = code"""
${input.code}
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${input.isNull}) {
if ($typeCheck) {
- ${ev.value} = (${ctx.boxedType(dataType)}) $obj;
+ ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj;
} else {
throw new RuntimeException($obj.getClass().getName() + $errMsgField);
}
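Illustration (not part of the patch): with the interpreted eval paths added above, object expressions such as UnwrapOption and WrapOption can now be evaluated without codegen. A minimal sketch, assuming the constructor signatures shown in this file and the Literal.fromObject / InternalRow.empty helpers from the same Catalyst version:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.expressions.objects.{UnwrapOption, WrapOption}
    import org.apache.spark.sql.types.{IntegerType, ObjectType}

    // UnwrapOption now evaluates directly: Some(1) -> 1 (or null for None).
    val opt = Literal.fromObject(Some(1), ObjectType(classOf[Option[Int]]))
    UnwrapOption(IntegerType, opt).eval(InternalRow.empty)      // 1
    // WrapOption is the inverse: 1 -> Some(1), null -> None.
    WrapOption(Literal(1), IntegerType).eval(InternalRow.empty) // Some(1)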
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 1a48995358af7..8a06daa37132d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -17,8 +17,12 @@
package org.apache.spark.sql.catalyst
+import java.util.Locale
+
import com.google.common.collect.Maps
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StructField, StructType}
@@ -138,6 +142,88 @@ package object expressions {
def indexOf(exprId: ExprId): Int = {
Option(exprIdToOrdinal.get(exprId)).getOrElse(-1)
}
+
+ private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = {
+ m.mapValues(_.distinct).map(identity)
+ }
+
+ /** Map to use for direct case insensitive attribute lookups. */
+ @transient private lazy val direct: Map[String, Seq[Attribute]] = {
+ unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT)))
+ }
+
+ /** Map to use for qualified case insensitive attribute lookups. */
+ @transient private val qualified: Map[(String, String), Seq[Attribute]] = {
+ val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a =>
+ (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT))
+ }
+ unique(grouped)
+ }
+
+ /** Perform attribute resolution given a name and a resolver. */
+ def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = {
+ // Collect matching attributes given a name and a lookup.
+ def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = {
+ candidates.toSeq.flatMap(_.collect {
+ case a if resolver(a.name, name) => a.withName(name)
+ })
+ }
+
+ // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name,
+ // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of
+ // matched attributes and a list of parts that are to be resolved.
+ //
+ // For example, take "a.b.c" where "a" is the table name, "b" is the column name,
+ // and "c" is the struct field name. In this case, the matched attribute is "a.b"
+ // and the remaining parts are List("c").
+ val matches = nameParts match {
+ case qualifier +: name +: nestedFields =>
+ val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT))
+ val attributes = collectMatches(name, qualified.get(key)).filter { a =>
+ resolver(qualifier, a.qualifier.get)
+ }
+ (attributes, nestedFields)
+ case all =>
+ (Nil, all)
+ }
+
+ // If none of the attributes match the `table.column` pattern, try to resolve the name as a column.
+ val (candidates, nestedFields) = matches match {
+ case (Seq(), _) =>
+ val name = nameParts.head
+ val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT)))
+ (attributes, nameParts.tail)
+ case _ => matches
+ }
+
+ def name = UnresolvedAttribute(nameParts).name
+ candidates match {
+ case Seq(a) if nestedFields.nonEmpty =>
+ // One match, but we also need to extract the requested nested field.
+ // The foldLeft adds an ExtractValue for every remaining part of the identifier,
+ // and aliases the result with the last part of the name.
+ // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
+ // Then this will produce ExtractValue(ExtractValue(a, "b"), "c"), aliased as "c".
+ val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) =>
+ ExtractValue(e, Literal(name), resolver)
+ }
+ Some(Alias(fieldExprs, nestedFields.last)())
+
+ case Seq(a) =>
+ // One match, no nested fields, use it.
+ Some(a)
+
+ case Seq() =>
+ // No matches.
+ None
+
+ case ambiguousReferences =>
+ // More than one match.
+ val referenceNames = ambiguousReferences.map(_.qualifiedName).mkString(", ")
+ throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.")
+ }
+ }
}
/**
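Illustration (not part of the patch): the resolve method above implements the "qualifier first, then plain column, then nested fields" order described in its comments. A hedged sketch, assuming the AttributeSeq wrapper and caseInsensitiveResolution from this version of Catalyst:

    import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.{IntegerType, StructType}

    // One attribute "b" of struct type, qualified by table alias "a".
    val col = AttributeReference("b", new StructType().add("c", IntegerType))(qualifier = Some("a"))
    // "a.b.c": "a" matches the qualifier, "b" the attribute, and "c" becomes
    // an ExtractValue aliased as "c" (per the nestedFields branch above).
    val resolved = new AttributeSeq(Seq(col)).resolve(Seq("a", "b", "c"), caseInsensitiveResolution)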
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index b469f5cb7586a..f54103c4fbfba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -21,7 +21,8 @@ import scala.collection.immutable.TreeSet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -36,6 +37,14 @@ object InterpretedPredicate {
case class InterpretedPredicate(expression: Expression) extends BasePredicate {
override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean]
+
+ override def initialize(partitionIndex: Int): Unit = {
+ super.initialize(partitionIndex)
+ expression.foreach {
+ case n: Nondeterministic => n.initialize(partitionIndex)
+ case _ =>
+ }
+ }
}
/**
@@ -157,7 +166,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
require(list != null, "list should not be null")
override def checkInputDataTypes(): TypeCheckResult = {
- val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType))
+ val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType,
+ ignoreNullability = true))
if (mismatchOpt.isDefined) {
list match {
case ListQuery(_, _, _, childOutputs) :: Nil =>
@@ -234,7 +244,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val javaDataType = ctx.javaType(value.dataType)
+ val javaDataType = CodeGenerator.javaType(value.dataType)
val valueGen = value.genCode(ctx)
val listGen = list.map(_.genCode(ctx))
// inTmpResult has 3 possible values:
@@ -262,8 +272,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = listCode,
funcName = "valueIn",
- extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil,
- returnType = ctx.JAVA_BYTE,
+ extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil,
+ returnType = CodeGenerator.JAVA_BYTE,
makeSplitFunction = body =>
s"""
|do {
@@ -281,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}.mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|${valueGen.code}
|byte $tmpResult = $HAS_NULL;
|if (!${valueGen.isNull}) {
@@ -345,10 +355,10 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
""
}
ev.copy(code =
- s"""
+ code"""
|${childGen.code}
- |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
- |${ctx.JAVA_BOOLEAN} ${ev.value} = false;
+ |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
+ |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
|if (!${ev.isNull}) {
| ${ev.value} = $setTerm.contains(${childGen.value});
| $setIsNull
@@ -397,16 +407,16 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
// The result should be `false`, if any of them is `false` whenever the other is null or not.
if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.value} = false;
if (${eval1.value}) {
${eval2.code}
${ev.value} = ${eval2.value};
- }""", isNull = "false")
+ }""", isNull = FalseLiteral)
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = false;
boolean ${ev.value} = false;
@@ -460,17 +470,17 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
// The result should be `true`, if any of them is `true` whenever the other is null or not.
if (!left.nullable && !right.nullable) {
- ev.isNull = "false"
- ev.copy(code = s"""
+ ev.isNull = FalseLiteral
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.value} = true;
if (!${eval1.value}) {
${eval2.code}
${ev.value} = ${eval2.value};
- }""", isNull = "false")
+ }""", isNull = FalseLiteral)
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = false;
boolean ${ev.value} = true;
@@ -504,7 +514,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- if (ctx.isPrimitiveType(left.dataType)
+ if (CodeGenerator.isPrimitiveType(left.dataType)
&& left.dataType != BooleanType // java boolean doesn't support > or < operator
&& left.dataType != FloatType
&& left.dataType != DoubleType) {
@@ -612,9 +622,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
val eval1 = left.genCode(ctx)
val eval2 = right.genCode(ctx)
val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value)
- ev.copy(code = eval1.code + eval2.code + s"""
+ ev.copy(code = eval1.code + eval2.code + code"""
boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||
- (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = "false")
+ (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral)
}
}
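Illustration (not part of the patch): InterpretedPredicate now forwards initialize() to nondeterministic children, so expressions like rand() can be evaluated on the interpreted path. A sketch, assuming the case-class constructors shown in this diff:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{GreaterThan, InterpretedPredicate, Literal, Rand}

    val pred = InterpretedPredicate(GreaterThan(Rand(Literal(42L)), Literal(0.5)))
    pred.initialize(0)                      // seeds the RNG inside Rand for partition 0
    val keep = pred.eval(InternalRow.empty) // evaluates without generating code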
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 8bc936fcbfc31..926c2f00d430d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -31,7 +32,7 @@ import org.apache.spark.util.random.XORShiftRandom
*
* Since this expression is stateful, it cannot be a case object.
*/
-abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
+abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful {
/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize and initialize it.
@@ -68,7 +69,8 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
0.8446490682263027
> SELECT _FUNC_(null);
0.8446490682263027
- """)
+ """,
+ note = "The function is non-deterministic in general case.")
// scalastyle:on line.size.limit
case class Rand(child: Expression) extends RDG {
@@ -81,9 +83,12 @@ case class Rand(child: Expression) extends RDG {
val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
- ev.copy(code = s"""
- final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false")
+ ev.copy(code = code"""
+ final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""",
+ isNull = FalseLiteral)
}
+
+ override def freshCopy(): Rand = Rand(child)
}
object Rand {
@@ -93,7 +98,7 @@ object Rand {
/** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.",
+ usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""",
examples = """
Examples:
> SELECT _FUNC_();
@@ -102,7 +107,8 @@ object Rand {
1.1164209726833079
> SELECT _FUNC_(null);
1.1164209726833079
- """)
+ """,
+ note = "The function is non-deterministic in general case.")
// scalastyle:on line.size.limit
case class Randn(child: Expression) extends RDG {
@@ -115,9 +121,12 @@ case class Randn(child: Expression) extends RDG {
val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
- ev.copy(code = s"""
- final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
+ ev.copy(code = code"""
+ final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
+ isNull = FalseLiteral)
}
+
+ override def freshCopy(): Randn = Randn(child)
}
object Randn {
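Illustration (not part of the patch): Rand and Randn are now Stateful, so each use site can obtain its own copy of the mutable RNG state via freshCopy(). A minimal sketch:

    import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}

    val r = Rand(Literal(7L))
    val r2 = r.freshCopy()   // same seed expression, independent XORShiftRandom state
    assert(r2 ne r)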
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index f3e8f6de58975..7b68bb771faf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern}
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -123,18 +124,18 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $pattern.matcher(${eval.value}.toString()).matches();
}
""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
""")
}
} else {
@@ -198,18 +199,18 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0);
}
""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
""")
}
} else {
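Illustration (not part of the patch): the recurring s""" -> code""" change throughout this diff switches generated snippets from raw Strings to Blocks. A hedged sketch of the interpolator, assuming the JavaCode/Block helpers imported via codegen.Block._ in this patch:

    import org.apache.spark.sql.catalyst.expressions.codegen._
    import org.apache.spark.sql.catalyst.expressions.codegen.Block._
    import org.apache.spark.sql.types.IntegerType

    val isNull = JavaCode.isNullVariable("isNull1")
    val value = JavaCode.variable("value1", IntegerType)
    // code"..." keeps the embedded JavaCode terms typed instead of flattening them to text.
    val block: Block = code"boolean $isNull = false; int $value = 0;"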
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index d7612e30b4c57..bedad7da334ae 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -36,87 +37,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
////////////////////////////////////////////////////////////////////////////////////////////////////
-/**
- * An expression that concatenates multiple inputs into a single output.
- * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
- * If any input is null, concat returns null.
- */
-@ExpressionDescription(
- usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.",
- examples = """
- Examples:
- > SELECT _FUNC_('Spark', 'SQL');
- SparkSQL
- """)
-case class Concat(children: Seq[Expression]) extends Expression {
-
- private lazy val isBinaryMode: Boolean = dataType == BinaryType
-
- override def checkInputDataTypes(): TypeCheckResult = {
- if (children.isEmpty) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- val childTypes = children.map(_.dataType)
- if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
- return TypeCheckResult.TypeCheckFailure(
- s"input to function $prettyName should have StringType or BinaryType, but it's " +
- childTypes.map(_.simpleString).mkString("[", ", ", "]"))
- }
- TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
- }
- }
-
- override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
-
- override def nullable: Boolean = children.exists(_.nullable)
- override def foldable: Boolean = children.forall(_.foldable)
-
- override def eval(input: InternalRow): Any = {
- if (isBinaryMode) {
- val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
- ByteArray.concat(inputs: _*)
- } else {
- val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
- UTF8String.concat(inputs : _*)
- }
- }
-
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val evals = children.map(_.genCode(ctx))
- val args = ctx.freshName("args")
-
- val inputs = evals.zipWithIndex.map { case (eval, index) =>
- s"""
- ${eval.code}
- if (!${eval.isNull}) {
- $args[$index] = ${eval.value};
- }
- """
- }
-
- val (concatenator, initCode) = if (isBinaryMode) {
- (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
- } else {
- ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
- }
- val codes = ctx.splitExpressionsWithCurrentInputs(
- expressions = inputs,
- funcName = "valueConcat",
- extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil)
- ev.copy(s"""
- $initCode
- $codes
- ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
- boolean ${ev.isNull} = ${ev.value} == null;
- """)
- }
-
- override def toString: String = s"concat(${children.mkString(", ")})"
-
- override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
-}
-
-
/**
* An expression that concatenates multiple input strings or array of strings into a single string,
* using a given separator (the first child).
@@ -186,7 +106,7 @@ case class ConcatWs(children: Seq[Expression])
expressions = inputs,
funcName = "valueConcatWs",
extraArguments = ("UTF8String[]", args) :: Nil)
- ev.copy(s"""
+ ev.copy(code"""
UTF8String[] $args = new UTF8String[$numArgs];
${separator.code}
$codes
@@ -196,7 +116,7 @@ case class ConcatWs(children: Seq[Expression])
} else {
val array = ctx.freshName("array")
val varargNum = ctx.freshName("varargNum")
- val idxInVararg = ctx.freshName("idxInVararg")
+ val idxVararg = ctx.freshName("idxInVararg")
val evals = children.map(_.genCode(ctx))
val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) =>
@@ -206,7 +126,7 @@ case class ConcatWs(children: Seq[Expression])
if (eval.isNull == "true") {
""
} else {
- s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
+ s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
})
case _: ArrayType =>
val size = ctx.freshName("n")
@@ -222,7 +142,7 @@ case class ConcatWs(children: Seq[Expression])
if (!${eval.isNull}) {
final int $size = ${eval.value}.numElements();
for (int j = 0; j < $size; j ++) {
- $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
+ $array[$idxVararg ++] = ${CodeGenerator.getValue(eval.value, StringType, "j")};
}
}
""")
@@ -230,7 +150,7 @@ case class ConcatWs(children: Seq[Expression])
}
}.unzip
- val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code))
+ val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString))
val varargCounts = ctx.splitExpressionsWithCurrentInputs(
expressions = varargCount,
@@ -247,20 +167,20 @@ case class ConcatWs(children: Seq[Expression])
val varargBuilds = ctx.splitExpressionsWithCurrentInputs(
expressions = varargBuild,
funcName = "varargBuildsConcatWs",
- extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
+ extraArguments = ("UTF8String []", array) :: ("int", idxVararg) :: Nil,
returnType = "int",
makeSplitFunction = body =>
s"""
|$body
- |return $idxInVararg;
+ |return $idxVararg;
""".stripMargin,
- foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n"))
+ foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n"))
ev.copy(
- s"""
+ code"""
$codes
int $varargNum = ${children.count(_.dataType == StringType) - 1};
- int $idxInVararg = 0;
+ int $idxVararg = 0;
$varargCounts
UTF8String[] $array = new UTF8String[$varargNum];
$varargBuilds
@@ -333,7 +253,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
val indexVal = ctx.freshName("index")
val indexMatched = ctx.freshName("eltIndexMatched")
- val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal")
+ val inputVal = ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal")
val assignInputValue = inputs.zipWithIndex.map { case (eval, index) =>
s"""
@@ -350,10 +270,10 @@ case class Elt(children: Seq[Expression]) extends Expression {
expressions = assignInputValue,
funcName = "eltFunc",
extraArguments = ("int", indexVal) :: Nil,
- returnType = ctx.JAVA_BOOLEAN,
+ returnType = CodeGenerator.JAVA_BOOLEAN,
makeSplitFunction = body =>
s"""
- |${ctx.JAVA_BOOLEAN} $indexMatched = false;
+ |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false;
|do {
| $body
|} while (false);
@@ -369,15 +289,15 @@ case class Elt(children: Seq[Expression]) extends Expression {
}.mkString)
ev.copy(
- s"""
+ code"""
|${index.code}
|final int $indexVal = ${index.value};
- |${ctx.JAVA_BOOLEAN} $indexMatched = false;
+ |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false;
|$inputVal = null;
|do {
| $codes
|} while (false);
- |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal;
+ |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal;
|final boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)
}
@@ -735,7 +655,7 @@ case class StringTrim(
val srcString = evals(0)
if (evals.length == 1) {
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -752,7 +672,7 @@ case class StringTrim(
} else {
${ev.value} = ${srcString.value}.trim(${trimString.value});
}"""
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -835,7 +755,7 @@ case class StringTrimLeft(
val srcString = evals(0)
if (evals.length == 1) {
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -852,7 +772,7 @@ case class StringTrimLeft(
} else {
${ev.value} = ${srcString.value}.trimLeft(${trimString.value});
}"""
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -937,7 +857,7 @@ case class StringTrimRight(
val srcString = evals(0)
if (evals.length == 1) {
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -954,7 +874,7 @@ case class StringTrimRight(
} else {
${ev.value} = ${srcString.value}.trimRight(${trimString.value});
}"""
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -1105,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
val substrGen = substr.genCode(ctx)
val strGen = str.genCode(ctx)
val startGen = start.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
int ${ev.value} = 0;
boolean ${ev.isNull} = false;
${startGen.code}
@@ -1410,10 +1330,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
val numArgLists = argListGen.length
val argListCode = argListGen.zipWithIndex.map { case(v, index) =>
val value =
- if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
+ if (CodeGenerator.boxedType(v._1) != CodeGenerator.javaType(v._1)) {
// Java primitives get boxed in order to allow null values.
- s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
- s"new ${ctx.boxedType(v._1)}(${v._2.value})"
+ s"(${v._2.isNull}) ? (${CodeGenerator.boxedType(v._1)}) null : " +
+ s"new ${CodeGenerator.boxedType(v._1)}(${v._2.value})"
} else {
s"(${v._2.isNull}) ? null : ${v._2.value}"
}
@@ -1431,10 +1351,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
val formatter = classOf[java.util.Formatter].getName
val sb = ctx.freshName("sb")
val stringBuffer = classOf[StringBuffer].getName
- ev.copy(code = s"""
+ ev.copy(code = code"""
${pattern.code}
boolean ${ev.isNull} = ${pattern.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
$stringBuffer $sb = new $stringBuffer();
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
@@ -1504,26 +1424,6 @@ case class StringRepeat(str: Expression, times: Expression)
}
}
-/**
- * Returns the reversed given string.
- */
-@ExpressionDescription(
- usage = "_FUNC_(str) - Returns the reversed given string.",
- examples = """
- Examples:
- > SELECT _FUNC_('Spark SQL');
- LQS krapS
- """)
-case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression {
- override def convert(v: UTF8String): UTF8String = v.reverse()
-
- override def prettyName: String = "reverse"
-
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- defineCodeGen(ctx, ev, c => s"($c).reverse()")
- }
-}
-
/**
* Returns a string consisting of n spaces.
*/
@@ -2016,12 +1916,15 @@ case class Encode(value: Expression, charset: Expression)
usage = """
_FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2`
decimal places. If `expr2` is 0, the result has no decimal point or fractional part.
+ `expr2` also accepts a user-specified format.
This is supposed to function like MySQL's FORMAT.
""",
examples = """
Examples:
> SELECT _FUNC_(12332.123456, 4);
12,332.1235
+ > SELECT _FUNC_(12332.123456, '##################.###');
+ 12332.123
""")
case class FormatNumber(x: Expression, d: Expression)
extends BinaryExpression with ExpectsInputTypes {
@@ -2030,14 +1933,20 @@ case class FormatNumber(x: Expression, d: Expression)
override def right: Expression = d
override def dataType: DataType = StringType
override def nullable: Boolean = true
- override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(NumericType, TypeCollection(IntegerType, StringType))
+
+ private val defaultFormat = "#,###,###,###,###,###,##0"
// Associated with the pattern, for the last d value, and we will update the
// pattern (DecimalFormat) once the new coming d value differ with the last one.
// This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after
// serialization (numberFormat has not been updated for dValue = 0).
@transient
- private var lastDValue: Option[Int] = None
+ private var lastDIntValue: Option[Int] = None
+
+ @transient
+ private var lastDStringValue: Option[String] = None
// A cached DecimalFormat, for performance concern, we will change it
// only if the d value changed.
@@ -2050,33 +1959,49 @@ case class FormatNumber(x: Expression, d: Expression)
private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US))
override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
- val dValue = dObject.asInstanceOf[Int]
- if (dValue < 0) {
- return null
- }
-
- lastDValue match {
- case Some(last) if last == dValue =>
- // use the current pattern
- case _ =>
- // construct a new DecimalFormat only if a new dValue
- pattern.delete(0, pattern.length)
- pattern.append("#,###,###,###,###,###,##0")
-
- // decimal place
- if (dValue > 0) {
- pattern.append(".")
-
- var i = 0
- while (i < dValue) {
- i += 1
- pattern.append("0")
- }
+ right.dataType match {
+ case IntegerType =>
+ val dValue = dObject.asInstanceOf[Int]
+ if (dValue < 0) {
+ return null
}
- lastDValue = Some(dValue)
+ lastDIntValue match {
+ case Some(last) if last == dValue =>
+ // use the current pattern
+ case _ =>
+ // construct a new DecimalFormat only if a new dValue
+ pattern.delete(0, pattern.length)
+ pattern.append(defaultFormat)
+
+ // decimal place
+ if (dValue > 0) {
+ pattern.append(".")
+
+ var i = 0
+ while (i < dValue) {
+ i += 1
+ pattern.append("0")
+ }
+ }
+
+ lastDIntValue = Some(dValue)
- numberFormat.applyLocalizedPattern(pattern.toString)
+ numberFormat.applyLocalizedPattern(pattern.toString)
+ }
+ case StringType =>
+ val dValue = dObject.asInstanceOf[UTF8String].toString
+ lastDStringValue match {
+ case Some(last) if last == dValue =>
+ case _ =>
+ pattern.delete(0, pattern.length)
+ lastDStringValue = Some(dValue)
+ if (dValue.isEmpty) {
+ numberFormat.applyLocalizedPattern(defaultFormat)
+ } else {
+ numberFormat.applyLocalizedPattern(dValue)
+ }
+ }
}
x.dataType match {
@@ -2108,34 +2033,52 @@ case class FormatNumber(x: Expression, d: Expression)
// SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.')
// as a decimal separator.
val usLocale = "US"
- val i = ctx.freshName("i")
- val dFormat = ctx.freshName("dFormat")
- val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;")
- val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
val numberFormat = ctx.addMutableState(df, "numberFormat",
v => s"""$v = new $df("", new $dfs($l.$usLocale));""")
- s"""
- if ($d >= 0) {
- $pattern.delete(0, $pattern.length());
- if ($d != $lastDValue) {
- $pattern.append("#,###,###,###,###,###,##0");
-
- if ($d > 0) {
- $pattern.append(".");
- for (int $i = 0; $i < $d; $i++) {
- $pattern.append("0");
+ right.dataType match {
+ case IntegerType =>
+ val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
+ val i = ctx.freshName("i")
+ val lastDValue =
+ ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
+ s"""
+ if ($d >= 0) {
+ $pattern.delete(0, $pattern.length());
+ if ($d != $lastDValue) {
+ $pattern.append("$defaultFormat");
+
+ if ($d > 0) {
+ $pattern.append(".");
+ for (int $i = 0; $i < $d; $i++) {
+ $pattern.append("0");
+ }
+ }
+ $lastDValue = $d;
+ $numberFormat.applyLocalizedPattern($pattern.toString());
+ }
+ ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
+ } else {
+ ${ev.value} = null;
+ ${ev.isNull} = true;
+ }
+ """
+ case StringType =>
+ val lastDValue = ctx.addMutableState("String", "lastDValue", v => s"""$v = null;""")
+ val dValue = ctx.freshName("dValue")
+ s"""
+ String $dValue = $d.toString();
+ if (!$dValue.equals($lastDValue)) {
+ $lastDValue = $dValue;
+ if ($dValue.isEmpty()) {
+ $numberFormat.applyLocalizedPattern("$defaultFormat");
+ } else {
+ $numberFormat.applyLocalizedPattern($dValue);
}
}
- $lastDValue = $d;
- $numberFormat.applyLocalizedPattern($pattern.toString());
- }
- ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
- } else {
- ${ev.value} = null;
- ${ev.isNull} = true;
- }
- """
+ ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
+ """
+ }
})
}
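The string-pattern branch added to FormatNumber above is easiest to see from the SQL surface. A minimal usage sketch, assuming a local SparkSession and that this patch is applied:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("format_number").getOrCreate()
    // Integer second argument: round to that many decimal places (existing behavior).
    spark.sql("SELECT format_number(12332.123456, 4)").show()                        // 12,332.1235
    // String second argument (new): interpret it as a DecimalFormat pattern.
    spark.sql("SELECT format_number(12332.123456, '##################.###')").show() // 12332.123
    spark.stop()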
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 78895f1c2f6f5..f957aaa96e98c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -21,7 +21,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp}
import org.apache.spark.sql.types._
/**
@@ -297,6 +297,37 @@ trait WindowFunction extends Expression {
def frame: WindowFrame = UnspecifiedFrame
}
+/**
+ * Case objects that describe whether a window function is a SQL window function or a Python
+ * user-defined window function.
+ */
+sealed trait WindowFunctionType
+
+object WindowFunctionType {
+ case object SQL extends WindowFunctionType
+ case object Python extends WindowFunctionType
+
+ def functionType(windowExpression: NamedExpression): WindowFunctionType = {
+ val t = windowExpression.collectFirst {
+ case _: WindowFunction | _: AggregateFunction => SQL
+ case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python
+ }
+
+ // Normally a window expression would either have a SQL window function, a SQL
+ // aggregate function or a Python window UDF. However, sometimes the optimizer will replace
+ // the window function if the value of the window function can be predetermined.
+ // For example, for query:
+ //
+ // select count(NULL) over () from values 1.0, 2.0, 3.0 T(a)
+ //
+ // The window function will be replaced by the expression literal(0).
+ // To handle this case, if a window expression doesn't have a regular window function, we
+ // consider its type to be SQL as literal(0) is also a SQL expression.
+ t.getOrElse(SQL)
+ }
+}
+
+
/**
* An offset window function is a window function that returns the value of the input column offset
* by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with
@@ -342,7 +373,10 @@ abstract class OffsetWindowFunction
override lazy val frame: WindowFrame = {
val boundary = direction match {
case Ascending => offset
- case Descending => UnaryMinus(offset)
+ case Descending => UnaryMinus(offset) match {
+ case e: Expression if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)
+ case o => o
+ }
}
SpecifiedWindowFrame(RowFrame, boundary, boundary)
}
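The classification performed by WindowFunctionType.functionType above boils down to a collectFirst over the expression tree with a SQL fallback. A standalone sketch of that logic, using hypothetical stand-in classes rather than the catalyst ones:

    sealed trait FnType
    case object SqlFn extends FnType
    case object PythonFn extends FnType

    // Simplified stand-ins for catalyst expressions (not the real classes).
    sealed trait Expr { def children: Seq[Expr] = Nil }
    case class SqlWindowFunction(name: String) extends Expr
    case class PythonWindowUDF(name: String) extends Expr
    case class Literal(value: Any) extends Expr
    case class Alias(child: Expr) extends Expr { override def children: Seq[Expr] = Seq(child) }

    // Return the type of the first window/aggregate function found in the tree; default to
    // SQL when the optimizer has already replaced the function, e.g. count(NULL) over () -> 0.
    def functionType(e: Expr): FnType = {
      def collectFirst(expr: Expr): Option[FnType] = expr match {
        case _: SqlWindowFunction => Some(SqlFn)
        case _: PythonWindowUDF   => Some(PythonFn)
        case other                => other.children.view.flatMap(collectFirst).headOption
      }
      collectFirst(e).getOrElse(SqlFn)
    }

    assert(functionType(Alias(PythonWindowUDF("pandas_mean"))) == PythonFn)
    assert(functionType(Alias(Literal(0))) == SqlFn)  // optimized-away window function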
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index d0185562c9cfc..aacf1a44e2ad0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -160,7 +160,7 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {
""")
// scalastyle:on line.size.limit
case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
- override def prettyName: String = "xpath_float"
+ override def prettyName: String = "xpath_double"
override def dataType: DataType = DoubleType
override def nullSafeEval(xml: Any, path: Any): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
index b1672e7e2fca2..3e8e6db1dbd22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala
@@ -18,10 +18,14 @@
package org.apache.spark.sql.catalyst.json
import java.io.{ByteArrayInputStream, InputStream, InputStreamReader}
+import java.nio.channels.Channels
+import java.nio.charset.Charset
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.hadoop.io.Text
+import sun.nio.cs.StreamDecoder
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String
private[sql] object CreateJacksonParser extends Serializable {
@@ -40,11 +44,51 @@ private[sql] object CreateJacksonParser extends Serializable {
}
def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
- val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength)
- jsonFactory.createParser(new InputStreamReader(bain, "UTF-8"))
+ jsonFactory.createParser(record.getBytes, 0, record.getLength)
+ }
+
+ // Jackson parsers can be ranked according to their performance:
+ // 1. Array based with actual encoding UTF-8 in the array. This is the fastest parser,
+ // but it doesn't allow the encoding to be set explicitly. The actual encoding is detected
+ // automatically by checking the leading bytes of the array.
+ // 2. InputStream based with actual encoding UTF-8 in the stream. Encoding is detected
+ // automatically by analyzing the first bytes of the input stream.
+ // 3. Reader based parser. This is the slowest parser used here, but it allows a reader
+ // to be created with a specific encoding.
+ // The method creates a reader for an array with the given encoding and sets the size of
+ // the internal decoding buffer according to the size of the input array.
+ private def getStreamDecoder(enc: String, in: Array[Byte], length: Int): StreamDecoder = {
+ val bais = new ByteArrayInputStream(in, 0, length)
+ val byteChannel = Channels.newChannel(bais)
+ val decodingBufferSize = Math.min(length, 8192)
+ val decoder = Charset.forName(enc).newDecoder()
+
+ StreamDecoder.forDecoder(byteChannel, decoder, decodingBufferSize)
+ }
+
+ def text(enc: String, jsonFactory: JsonFactory, record: Text): JsonParser = {
+ val sd = getStreamDecoder(enc, record.getBytes, record.getLength)
+ jsonFactory.createParser(sd)
+ }
+
+ def inputStream(jsonFactory: JsonFactory, is: InputStream): JsonParser = {
+ jsonFactory.createParser(is)
+ }
+
+ def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = {
+ jsonFactory.createParser(new InputStreamReader(is, enc))
}
- def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = {
- jsonFactory.createParser(new InputStreamReader(record, "UTF-8"))
+ def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
+ val ba = row.getBinary(0)
+
+ jsonFactory.createParser(ba, 0, ba.length)
+ }
+
+ def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = {
+ val binary = row.getBinary(0)
+ val sd = getStreamDecoder(enc, binary, binary.length)
+
+ jsonFactory.createParser(sd)
}
}
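The per-encoding parsers above wrap the record bytes in a decoder so Jackson never has to guess the charset. A hedged sketch of the same idea using only public JDK classes (the patch itself uses sun.nio.cs.StreamDecoder to size the decoding buffer to the input):

    import java.io.{ByteArrayInputStream, InputStreamReader}
    import java.nio.charset.Charset
    import com.fasterxml.jackson.core.{JsonFactory, JsonParser}

    // Reader-based parser (option 3 in the ranking above) for bytes in an explicit encoding.
    def parserFor(enc: String, bytes: Array[Byte], length: Int, factory: JsonFactory): JsonParser = {
      val in = new ByteArrayInputStream(bytes, 0, length)
      factory.createParser(new InputStreamReader(in, Charset.forName(enc)))
    }

    val factory = new JsonFactory()
    val utf16 = """{"id": 1}""".getBytes("UTF-16LE")
    val parser = parserFor("UTF-16LE", utf16, utf16.length, factory)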
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 652412b34478a..c081772116f84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.json
+import java.nio.charset.{Charset, StandardCharsets}
import java.util.{Locale, TimeZone}
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
@@ -31,7 +32,7 @@ import org.apache.spark.sql.catalyst.util._
* Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
*/
private[sql] class JSONOptions(
- @transient private val parameters: CaseInsensitiveMap[String],
+ @transient val parameters: CaseInsensitiveMap[String],
defaultTimeZoneId: String,
defaultColumnNameOfCorruptRecord: String)
extends Logging with Serializable {
@@ -72,6 +73,9 @@ private[sql] class JSONOptions(
val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
+ // Whether to ignore columns of all null values or empty arrays/structs during schema inference
+ val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false)
+
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
@@ -85,6 +89,46 @@ private[sql] class JSONOptions(
val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
+ /**
+ * A string between two consecutive JSON records.
+ */
+ val lineSeparator: Option[String] = parameters.get("lineSep").map { sep =>
+ require(sep.nonEmpty, "'lineSep' cannot be an empty string.")
+ sep
+ }
+
+ /**
+ * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE.
+ * If the encoding is not specified (None), it will be detected automatically
+ * when the multiLine option is set to `true`.
+ */
+ val encoding: Option[String] = parameters.get("encoding")
+ .orElse(parameters.get("charset")).map { enc =>
+ // The following encodings are not supported in per-line mode (multiLine is false)
+ // because they cause problems when reading files with a BOM, which is expected to be
+ // present in files with such encodings. After splitting input files by lines,
+ // only the first line has the BOM, which makes it impossible to read the remaining
+ // lines. Besides that, the lineSep option would have to contain the BOM in such
+ // encodings, which can never be present between lines.
+ val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32"))
+ val isBlacklisted = blacklist.contains(Charset.forName(enc))
+ require(multiLine || !isBlacklisted,
+ s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled.
+ |Blacklist: ${blacklist.mkString(", ")}""".stripMargin)
+
+ val isLineSepRequired =
+ multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty
+
+ require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding")
+
+ enc
+ }
+
+ val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
+ lineSep.getBytes(encoding.getOrElse("UTF-8"))
+ }
+ val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n")
+
/** Sets config options on a Jackson [[JsonFactory]]. */
def setJacksonOptions(factory: JsonFactory): Unit = {
factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments)
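From the user side, the new options surface through the JSON data source. A minimal usage sketch, assuming an active SparkSession named spark and hypothetical file paths:

    // Per-line mode with a non-UTF-8 encoding: per the checks above, lineSep must be set
    // explicitly (and UTF-16/UTF-32 with BOM are rejected in this mode).
    val perLine = spark.read
      .option("encoding", "UTF-16LE")
      .option("lineSep", "\n")
      .json("/tmp/people-utf16le.json")          // hypothetical path

    // multiLine mode may detect the encoding automatically, or take it explicitly.
    val whole = spark.read
      .option("multiLine", "true")
      .json("/tmp/people.json")                  // hypothetical path

    // Drop columns that are all-null (or empty arrays/structs) during schema inference.
    val pruned = spark.read
      .option("dropFieldIfAllNull", "true")
      .json("/tmp/sparse.json")                  // hypothetical path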
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index eb06e4f304f0a..9c413de752a8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.json
import java.io.Writer
+import java.nio.charset.StandardCharsets
import com.fasterxml.jackson.core._
@@ -74,6 +75,8 @@ private[sql] class JacksonGenerator(
private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+ private val lineSeparator: String = options.lineSeparatorInWrite
+
private def makeWriter(dataType: DataType): ValueWriter = dataType match {
case NullType =>
(row: SpecializedGetters, ordinal: Int) =>
@@ -251,5 +254,8 @@ private[sql] class JacksonGenerator(
mapType = dataType.asInstanceOf[MapType]))
}
- def writeLineEnding(): Unit = gen.writeRaw('\n')
+ def writeLineEnding(): Unit = {
+ // Note that JSON uses a writer with the UTF-8 charset. This string will be written out as UTF-8.
+ gen.writeRaw(lineSeparator)
+ }
}
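On the write side, the generator now emits options.lineSeparatorInWrite after each record, so a custom separator can be requested when writing JSON. A hedged sketch for some DataFrame df (output path hypothetical):

    // Separate records with CRLF instead of the default "\n".
    df.write
      .option("lineSep", "\r\n")
      .json("/tmp/out-json")                     // hypothetical path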
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index bd144c9575c72..c3a4ca8f64bf6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.json
-import java.io.ByteArrayOutputStream
+import java.io.{ByteArrayOutputStream, CharConversionException}
import scala.collection.mutable.ArrayBuffer
import scala.util.Try
@@ -36,7 +36,7 @@ import org.apache.spark.util.Utils
* Constructs a parser for a given schema that translates a json string to an [[InternalRow]].
*/
class JacksonParser(
- schema: StructType,
+ schema: DataType,
val options: JSONOptions) extends Logging {
import JacksonUtils._
@@ -57,7 +57,14 @@ class JacksonParser(
* to a value according to a desired schema. This is a wrapper for the method
* `makeConverter()` to handle a row wrapped with an array.
*/
- private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = {
+ private def makeRootConverter(dt: DataType): JsonParser => Seq[InternalRow] = {
+ dt match {
+ case st: StructType => makeStructRootConverter(st)
+ case mt: MapType => makeMapRootConverter(mt)
+ }
+ }
+
+ private def makeStructRootConverter(st: StructType): JsonParser => Seq[InternalRow] = {
val elementConverter = makeConverter(st)
val fieldConverters = st.map(_.dataType).map(makeConverter).toArray
(parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) {
@@ -87,6 +94,13 @@ class JacksonParser(
}
}
+ private def makeMapRootConverter(mt: MapType): JsonParser => Seq[InternalRow] = {
+ val fieldConverter = makeConverter(mt.valueType)
+ (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, mt) {
+ case START_OBJECT => Seq(InternalRow(convertMap(parser, fieldConverter)))
+ }
+ }
+
/**
* Create a converter which converts the JSON documents held by the `JsonParser`
* to a value according to a desired schema.
@@ -357,7 +371,18 @@ class JacksonParser(
}
} catch {
case e @ (_: RuntimeException | _: JsonProcessingException) =>
+ // JSON parser currently doesn't support partial results for corrupted records.
+ // For such records, all fields other than the field configured by
+ // `columnNameOfCorruptRecord` are set to `null`.
throw BadRecordException(() => recordLiteral(record), () => None, e)
+ case e: CharConversionException if options.encoding.isEmpty =>
+ val msg =
+ """JSON parser cannot handle a character in its input.
+ |Specifying encoding as an input option explicitly might help to resolve the issue.
+ |""".stripMargin + e.getMessage
+ val wrappedCharException = new CharConversionException(msg)
+ wrappedCharException.initCause(e)
+ throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException)
}
}
}
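With a MapType now accepted as the root schema, from_json can return a map directly. A minimal sketch, assuming an active SparkSession with spark.implicits._ imported:

    import org.apache.spark.sql.functions.from_json
    import org.apache.spark.sql.types.{IntegerType, MapType, StringType}

    val df = Seq("""{"a": 1, "b": 2}""").toDF("json")
    // Handled by the new makeMapRootConverter: the whole object becomes one map value.
    val parsed = df.select(from_json($"json", MapType(StringType, IntegerType)).as("m"))
    parsed.show(false)   // expected: Map(a -> 1, b -> 2)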
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index be0009ec8c760..db7d6d3254bd2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -18,39 +18,39 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
/**
-* push down operations into [[CreateNamedStructLike]].
-*/
-object SimplifyCreateStructOps extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = {
- plan.transformExpressionsUp {
- // push down field extraction
+ * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions.
+ */
+object SimplifyExtractValueOps extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // One place where this optimization is invalid is an aggregation where the select
+ // list expression is a function of a grouping expression:
+ //
+ // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
+ //
+ // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
+ // optimization for Aggregates (although this misses some cases where the optimization
+ // can be made).
+ case a: Aggregate => a
+ case p => p.transformExpressionsUp {
+ // Remove redundant field extraction.
case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) =>
createNamedStructLike.valExprs(ordinal)
- }
- }
-}
-/**
-* push down operations into [[CreateArray]].
-*/
-object SimplifyCreateArrayOps extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = {
- plan.transformExpressionsUp {
- // push down field selection (array of structs)
- case GetArrayStructFields(CreateArray(elems), field, ordinal, numFields, containsNull) =>
- // instead f selecting the field on the entire array,
- // select it from each member of the array.
- // pushing down the operation this way open other optimizations opportunities
- // (i.e. struct(...,x,...).x)
+ // Remove redundant array indexing.
+ case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) =>
+ // Instead of selecting the field on the entire array, select it from each member
+ // of the array. Pushing down the operation this way may open other optimization
+ // opportunities (i.e. struct(...,x,...).x).
CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))))
- // push down item selection.
+
+ // Remove redundant map lookup.
case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) =>
- // instead of creating the array and then selecting one row,
- // remove array creation altgether.
+ // Instead of creating the array and then selecting one row, remove array creation
+ // altogether.
if (idx >= 0 && idx < elems.size) {
// valid index
elems(idx)
@@ -58,18 +58,7 @@ object SimplifyCreateArrayOps extends Rule[LogicalPlan] {
// out of bounds, mimic the runtime behavior and return null
Literal(null, ga.dataType)
}
- }
- }
-}
-
-/**
-* push down operations into [[CreateMap]].
-*/
-object SimplifyCreateMapOps extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = {
- plan.transformExpressionsUp {
case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems)
}
}
}
-
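The effect of the consolidated rule can be observed at the DataFrame level: extracting a field from a struct that is created in the same projection collapses to a plain column reference (outside of Aggregates). A hedged sketch, assuming an active SparkSession with spark.implicits._ imported:

    import org.apache.spark.sql.functions.{lit, struct}

    // struct(id, 1).a simplifies to id in the optimized plan; no struct is ever built.
    val df = spark.range(5)
      .select(struct($"id".as("a"), lit(1).as("b")).getField("a").as("x"))
    df.explain(true)   // the Optimized Logical Plan should project `id` directly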
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a28b6a0feb8f9..aa992def1ce6c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -85,9 +85,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
EliminateSerialization,
RemoveRedundantAliases,
RemoveRedundantProject,
- SimplifyCreateStructOps,
- SimplifyCreateArrayOps,
- SimplifyCreateMapOps,
+ SimplifyExtractValueOps,
CombineConcats) ++
extendedOperatorOptimizationRules
@@ -140,6 +138,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
operatorOptimizationBatch) :+
Batch("Join Reorder", Once,
CostBasedJoinReorder) :+
+ Batch("Remove Redundant Sorts", Once,
+ RemoveRedundantSorts) :+
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) :+
Batch("Object Expressions Optimization", fixedPoint,
@@ -155,7 +155,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
RewritePredicateSubquery,
ColumnPruning,
CollapseProject,
- RemoveRedundantProject)
+ RemoveRedundantProject) :+
+ Batch("UpdateAttributeReferences", Once,
+ UpdateNullabilityInAttributeReferences)
}
/**
@@ -619,12 +621,15 @@ object CollapseRepartition extends Rule[LogicalPlan] {
/**
* Collapse Adjacent Window Expression.
* - If the partition specs and order specs are the same and the window expression are
- * independent, collapse into the parent.
+ * independent and are of the same window function type, collapse into the parent.
*/
object CollapseWindow extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
- if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty =>
+ if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty &&
+ // This assumes Window contains the same type of window expressions. This is ensured
+ // by ExtractWindowExpressions.
+ WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) =>
w1.copy(windowExpressions = we2 ++ we1, child = grandChild)
}
}
@@ -635,10 +640,11 @@ object CollapseWindow extends Rule[LogicalPlan] {
* constraints. These filters are currently inserted to the existing conditions in the Filter
* operators and on either side of Join operators.
*
- * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and
- * LeftSemi joins.
+ * Note: While this optimization is applicable to many join types, it primarily benefits
+ * Inner and LeftSemi joins.
*/
-object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
+object InferFiltersFromConstraints extends Rule[LogicalPlan]
+ with PredicateHelper with ConstraintHelper {
def apply(plan: LogicalPlan): LogicalPlan = {
if (SQLConf.get.constraintPropagationEnabled) {
@@ -659,21 +665,51 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
}
case join @ Join(left, right, joinType, conditionOpt) =>
- // Only consider constraints that can be pushed down completely to either the left or the
- // right child
- val constraints = join.constraints.filter { c =>
- c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)
- }
- // Remove those constraints that are already enforced by either the left or the right child
- val additionalConstraints = constraints -- (left.constraints ++ right.constraints)
- val newConditionOpt = conditionOpt match {
- case Some(condition) =>
- val newFilters = additionalConstraints -- splitConjunctivePredicates(condition)
- if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else None
- case None =>
- additionalConstraints.reduceOption(And)
+ joinType match {
+ // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
+ // inner join, it just drops the right side in the final output.
+ case _: InnerLike | LeftSemi =>
+ val allConstraints = getAllConstraints(left, right, conditionOpt)
+ val newLeft = inferNewFilter(left, allConstraints)
+ val newRight = inferNewFilter(right, allConstraints)
+ join.copy(left = newLeft, right = newRight)
+
+ // For right outer join, we can only infer additional filters for left side.
+ case RightOuter =>
+ val allConstraints = getAllConstraints(left, right, conditionOpt)
+ val newLeft = inferNewFilter(left, allConstraints)
+ join.copy(left = newLeft)
+
+ // For left join, we can only infer additional filters for right side.
+ case LeftOuter | LeftAnti =>
+ val allConstraints = getAllConstraints(left, right, conditionOpt)
+ val newRight = inferNewFilter(right, allConstraints)
+ join.copy(right = newRight)
+
+ case _ => join
}
- if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join
+ }
+
+ private def getAllConstraints(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ conditionOpt: Option[Expression]): Set[Expression] = {
+ val baseConstraints = left.constraints.union(right.constraints)
+ .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet)
+ baseConstraints.union(inferAdditionalConstraints(baseConstraints))
+ }
+
+ private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = {
+ val newPredicates = constraints
+ .union(constructIsNotNullConstraints(constraints, plan.output))
+ .filter { c =>
+ c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
+ } -- plan.constraints
+ if (newPredicates.isEmpty) {
+ plan
+ } else {
+ Filter(newPredicates.reduce(And), plan)
+ }
}
}
@@ -733,6 +769,33 @@ object EliminateSorts extends Rule[LogicalPlan] {
}
}
+/**
+ * Removes redundant Sort operation. This can happen:
+ * 1) if the child is already sorted
+ * 2) if there is another Sort operator separated by 0...n Project/Filter operators
+ */
+object RemoveRedundantSorts extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+ case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
+ child
+ case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child))
+ }
+
+ def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match {
+ case Sort(_, _, child) => recursiveRemoveSort(child)
+ case other if canEliminateSort(other) =>
+ other.withNewChildren(other.children.map(recursiveRemoveSort))
+ case _ => plan
+ }
+
+ def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
+ case p: Project => p.projectList.forall(_.deterministic)
+ case f: Filter => f.condition.deterministic
+ case _: ResolvedHint => true
+ case _ => false
+ }
+}
+
/**
* Removes filters that can be evaluated trivially. This can be done through the following ways:
* 1) by eliding the filter for cases where it will always evaluate to `true`.
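A hedged sketch of the plans RemoveRedundantSorts targets: an inner sort separated from an outer sort only by deterministic projections and filters, assuming an active SparkSession with spark.implicits._ imported:

    // The inner orderBy is redundant; recursiveRemoveSort drops it because only
    // deterministic Project/Filter operators sit between it and the outer orderBy.
    val df = spark.range(100)
      .orderBy($"id".desc)               // removed by the rule
      .filter($"id" > 10)                // deterministic filter
      .select(($"id" * 2).as("x"))       // deterministic project
      .orderBy($"x")                     // the sort that survives
    df.explain(true)                     // expect a single Sort in the optimized plan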
@@ -1122,12 +1185,14 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper {
case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _)
if isCartesianProduct(j) =>
throw new AnalysisException(
- s"""Detected cartesian product for ${j.joinType.sql} join between logical plans
+ s"""Detected implicit cartesian product for ${j.joinType.sql} join between logical plans
|${left.treeString(false).trim}
|and
|${right.treeString(false).trim}
|Join condition is missing or trivial.
- |Use the CROSS JOIN syntax to allow cartesian products between these relations."""
+ |Either: use the CROSS JOIN syntax to allow cartesian products between these
+ |relations, or: enable implicit cartesian products by setting the configuration
+ |variable spark.sql.crossJoin.enabled=true"""
.stripMargin)
}
}
@@ -1311,3 +1376,18 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Updates nullability in [[AttributeReference]]s if nullability is different between
+ * a non-leaf plan's expressions and its children's output.
+ */
+object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case p if !p.isInstanceOf[LeafNode] =>
+ val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable })
+ p transformExpressions {
+ case ar: AttributeReference if nullabilityMap.contains(ar) =>
+ ar.withNullability(nullabilityMap(ar))
+ }
+ }
+}
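The join-type cases in the rewritten InferFiltersFromConstraints are easiest to check with explain(). A hedged sketch, assuming constraint propagation is enabled (spark.sql.constraintPropagation.enabled=true, the default) and an active SparkSession:

    // For a LEFT OUTER join, new filters are inferred only for the right side:
    // isnotnull(r.id) and r.id > 5 can be placed below the right child, but nothing
    // is added below the preserved (left) side.
    val l = spark.range(10).toDF("id")
    val r = spark.range(10).toDF("id")
    val joined = l.join(r, l("id") === r("id") && r("id") > 5, "left_outer")
    joined.explain(true)   // look for an inferred Filter above the right-side Range only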
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index a6e5aa6daca65..c3fdb924243df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.catalyst.optimizer
+import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
/**
* Collapse plans consisting empty local relations generated by [[PruneFilters]].
@@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules._
* - Aggregate with all empty children and at least one grouping expression.
* - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
*/
-object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
+object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper with CastSupport {
private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = plan match {
case p: LocalRelation => p.data.isEmpty
case _ => false
@@ -43,7 +45,9 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
// Construct a project list from plan's output, while the value is always NULL.
private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] =
- plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) }
+ plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) }
+
+ override def conf: SQLConf = SQLConf.get
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p: Union if p.children.forall(isEmptyLocalRelation) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala
index 1f20b7661489e..2aa762e2595ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala
@@ -187,11 +187,11 @@ object StarSchemaDetection extends PredicateHelper {
stats.rowCount match {
case Some(rowCount) if rowCount >= 0 =>
if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) {
- val colStats = stats.attributeStats.get(col)
- if (colStats.get.nullCount > 0) {
+ val colStats = stats.attributeStats.get(col).get
+ if (!colStats.hasCountStats || colStats.nullCount.get > 0) {
false
} else {
- val distinctCount = colStats.get.distinctCount
+ val distinctCount = colStats.distinctCount.get
val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d)
// ndvMaxErr adjusted based on TPCDS 1TB data results
relDiff <= conf.ndvMaxError * 2
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 1c0b7bd806801..1d363b8146e3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet
import scala.collection.mutable.{ArrayBuffer, Stack}
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 709db6d8bec7d..de89e17e51f1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -116,15 +116,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// (a1,a2,...) = (b1,b2,...)
// to
// (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ...
- val joinConds = splitConjunctivePredicates(joinCond.get)
+ val baseJoinConds = splitConjunctivePredicates(joinCond.get)
+ val nullAwareJoinConds = baseJoinConds.map(c => Or(c, IsNull(c)))
// After that, add back the correlated join predicate(s) in the subquery
// Example:
// SELECT ... FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1)
// will have the final conditions in the LEFT ANTI as
- // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2)
- val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And)
+ // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
+ val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
// Deduplicate conflicting attributes if any.
- dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs)))
+ dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
case (p, predicate) =>
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index bdc357d54a878..383ebde3229d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -503,7 +503,14 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
val join = right.optionalMap(left)(Join(_, _, Inner, None))
withJoinRelations(join, relation)
}
- ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+ if (ctx.pivotClause() != null) {
+ if (!ctx.lateralView.isEmpty) {
+ throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx)
+ }
+ withPivot(ctx.pivotClause, from)
+ } else {
+ ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+ }
}
/**
@@ -614,6 +621,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
plan
}
+ /**
+ * Add a [[Pivot]] to a logical plan.
+ */
+ private def withPivot(
+ ctx: PivotClauseContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val aggregates = Option(ctx.aggregates).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
+ val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText)
+ val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply)
+ Pivot(None, pivotColumn, pivotValues, aggregates, query)
+ }
+
/**
* Add a [[Generate]] (Lateral View) to a logical plan.
*/
@@ -1185,6 +1206,34 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
new StringLocate(expression(ctx.substr), expression(ctx.str))
}
+ /**
+ * Create a Extract expression.
+ */
+ override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) {
+ ctx.field.getText.toUpperCase(Locale.ROOT) match {
+ case "YEAR" =>
+ Year(expression(ctx.source))
+ case "QUARTER" =>
+ Quarter(expression(ctx.source))
+ case "MONTH" =>
+ Month(expression(ctx.source))
+ case "WEEK" =>
+ WeekOfYear(expression(ctx.source))
+ case "DAY" =>
+ DayOfMonth(expression(ctx.source))
+ case "DAYOFWEEK" =>
+ DayOfWeek(expression(ctx.source))
+ case "HOUR" =>
+ Hour(expression(ctx.source))
+ case "MINUTE" =>
+ Minute(expression(ctx.source))
+ case "SECOND" =>
+ Second(expression(ctx.source))
+ case other =>
+ throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx)
+ }
+ }
+
/**
* Create a (windowed) Function expression.
*/
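Both parser additions are visible from the SQL surface. A hedged usage sketch, assuming an active SparkSession and a hypothetical temporary view courseSales(year, course, earnings):

    // EXTRACT is rewritten by visitExtract into the matching datetime expression (Year, ...).
    spark.sql("SELECT extract(YEAR FROM timestamp '2018-06-01 12:30:00')").show()   // 2018

    // A PIVOT clause in FROM is turned into a Pivot logical node by withPivot.
    spark.sql("""
      SELECT * FROM (
        SELECT year, course, earnings FROM courseSales
      )
      PIVOT (
        sum(earnings) FOR course IN ('dotNET', 'Java')
      )
    """).show()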
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 626f905707191..84be677e438a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.planning
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
@@ -215,7 +216,7 @@ object PhysicalAggregation {
case agg: AggregateExpression
if !equivalentAggregateExpressions.addExpr(agg) => agg
case udf: PythonUDF
- if PythonUDF.isGroupAggPandasUDF(udf) &&
+ if PythonUDF.isGroupedAggPandasUDF(udf) &&
!equivalentAggregateExpressions.addExpr(udf) => udf
}
}
@@ -245,7 +246,7 @@ object PhysicalAggregation {
equivalentAggregateExpressions.getEquivalentExprs(ae).headOption
.getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
// Similar to AggregateExpression
- case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) =>
+ case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
equivalentAggregateExpressions.getEquivalentExprs(ue).headOption
.getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
case expression =>
@@ -268,3 +269,40 @@ object PhysicalAggregation {
case _ => None
}
}
+
+/**
+ * An extractor used when planning physical execution of a window. This extractor outputs
+ * the window function type of the logical window.
+ *
+ * The input logical window must contain the same type of window functions, which is ensured by
+ * the rule ExtractWindowExpressions in the analyzer.
+ */
+object PhysicalWindow {
+ // windowFunctionType, windowExpression, partitionSpec, orderSpec, child
+ private type ReturnType =
+ (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan)
+
+ def unapply(a: Any): Option[ReturnType] = a match {
+ case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) =>
+
+ // The window expression should not be empty here, otherwise it's a bug.
+ if (windowExpressions.isEmpty) {
+ throw new AnalysisException(s"Window expression is empty in $expr")
+ }
+
+ val windowFunctionType = windowExpressions.map(WindowFunctionType.functionType)
+ .reduceLeft { (t1: WindowFunctionType, t2: WindowFunctionType) =>
+ if (t1 != t2) {
+ // We shouldn't have different window function type here, otherwise it's a bug.
+ throw new AnalysisException(
+ s"Found different window function type in $windowExpressions")
+ } else {
+ t1
+ }
+ }
+
+ Some((windowFunctionType, windowExpressions, partitionSpec, orderSpec, child))
+
+ case _ => None
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index ddf2cbf2ab911..e431c9523a9da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
@@ -103,7 +103,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
var changed = false
@inline def transformExpression(e: Expression): Expression = {
- val newE = f(e)
+ val newE = CurrentOrigin.withOrigin(e.origin) {
+ f(e)
+ }
if (newE.fastEquals(e)) {
e
} else {
@@ -117,6 +119,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
case Some(value) => Some(recursiveTransform(value))
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
+ case stream: Stream[_] => stream.map(recursiveTransform).force
case seq: Traversable[_] => seq.map(recursiveTransform)
case other: AnyRef => other
case null => null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index d73d7e73f28d5..8c4828a4cef23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.types.{StructField, StructType}
object LocalRelation {
@@ -43,10 +44,17 @@ object LocalRelation {
}
}
-case class LocalRelation(output: Seq[Attribute],
- data: Seq[InternalRow] = Nil,
- // Indicates whether this relation has data from a streaming source.
- override val isStreaming: Boolean = false)
+/**
+ * Logical plan node for scanning data from a local collection.
+ *
+ * @param data The local collection holding the data. It doesn't need to be sent to executors
+ * and therefore doesn't need to be serializable.
+ */
+case class LocalRelation(
+ output: Seq[Attribute],
+ data: Seq[InternalRow] = Nil,
+ // Indicates whether this relation has data from a streaming source.
+ override val isStreaming: Boolean = false)
extends LeafNode with analysis.MultiInstanceRelation {
// A local relation must have resolved output.
@@ -70,7 +78,7 @@ case class LocalRelation(output: Seq[Attribute],
}
override def computeStats(): Statistics =
- Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
+ Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length)
def toSQL(inlineTableName: String): String = {
require(data.nonEmpty)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index c8ccd9bd03994..c486ad700f362 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -78,7 +78,7 @@ abstract class LogicalPlan
schema.map { field =>
resolve(field.name :: Nil, resolver).map {
case a: AttributeReference => a
- case other => sys.error(s"can not handle nested schema yet... plan $this")
+ case _ => sys.error(s"can not handle nested schema yet... plan $this")
}.getOrElse {
throw new AnalysisException(
s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]")
@@ -86,6 +86,10 @@ abstract class LogicalPlan
}
}
+ private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output))
+
+ private[this] lazy val outputAttributes = AttributeSeq(output)
+
/**
* Optionally resolves the given strings to a [[NamedExpression]] using the input from all child
* nodes of this LogicalPlan. The attribute is expressed as
@@ -94,7 +98,7 @@ abstract class LogicalPlan
def resolveChildren(
nameParts: Seq[String],
resolver: Resolver): Option[NamedExpression] =
- resolve(nameParts, children.flatMap(_.output), resolver)
+ childAttributes.resolve(nameParts, resolver)
/**
* Optionally resolves the given strings to a [[NamedExpression]] based on the output of this
@@ -104,7 +108,7 @@ abstract class LogicalPlan
def resolve(
nameParts: Seq[String],
resolver: Resolver): Option[NamedExpression] =
- resolve(nameParts, output, resolver)
+ outputAttributes.resolve(nameParts, resolver)
/**
* Given an attribute name, split it to name parts by dot, but
@@ -114,111 +118,18 @@ abstract class LogicalPlan
def resolveQuoted(
name: String,
resolver: Resolver): Option[NamedExpression] = {
- resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver)
- }
-
- /**
- * Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
- *
- * This assumes `name` has multiple parts, where the 1st part is a qualifier
- * (i.e. table name, alias, or subquery alias).
- * See the comment above `candidates` variable in resolve() for semantics the returned data.
- */
- private def resolveAsTableColumn(
- nameParts: Seq[String],
- resolver: Resolver,
- attribute: Attribute): Option[(Attribute, List[String])] = {
- assert(nameParts.length > 1)
- if (attribute.qualifier.exists(resolver(_, nameParts.head))) {
- // At least one qualifier matches. See if remaining parts match.
- val remainingParts = nameParts.tail
- resolveAsColumn(remainingParts, resolver, attribute)
- } else {
- None
- }
+ outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver)
}
/**
- * Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
- *
- * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier.
- * See the comment above `candidates` variable in resolve() for semantics the returned data.
+ * Refreshes (or invalidates) any metadata/data cached in the plan recursively.
*/
- private def resolveAsColumn(
- nameParts: Seq[String],
- resolver: Resolver,
- attribute: Attribute): Option[(Attribute, List[String])] = {
- if (resolver(attribute.name, nameParts.head)) {
- Option((attribute.withName(nameParts.head), nameParts.tail.toList))
- } else {
- None
- }
- }
-
- /** Performs attribute resolution given a name and a sequence of possible attributes. */
- protected def resolve(
- nameParts: Seq[String],
- input: Seq[Attribute],
- resolver: Resolver): Option[NamedExpression] = {
-
- // A sequence of possible candidate matches.
- // Each candidate is a tuple. The first element is a resolved attribute, followed by a list
- // of parts that are to be resolved.
- // For example, consider an example where "a" is the table name, "b" is the column name,
- // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b",
- // and the second element will be List("c").
- var candidates: Seq[(Attribute, List[String])] = {
- // If the name has 2 or more parts, try to resolve it as `table.column` first.
- if (nameParts.length > 1) {
- input.flatMap { option =>
- resolveAsTableColumn(nameParts, resolver, option)
- }
- } else {
- Seq.empty
- }
- }
-
- // If none of attributes match `table.column` pattern, we try to resolve it as a column.
- if (candidates.isEmpty) {
- candidates = input.flatMap { candidate =>
- resolveAsColumn(nameParts, resolver, candidate)
- }
- }
-
- def name = UnresolvedAttribute(nameParts).name
-
- candidates.distinct match {
- // One match, no nested fields, use it.
- case Seq((a, Nil)) => Some(a)
-
- // One match, but we also need to extract the requested nested field.
- case Seq((a, nestedFields)) =>
- // The foldLeft adds ExtractValues for every remaining parts of the identifier,
- // and aliased it with the last part of the name.
- // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
- // expression as "c".
- val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
- ExtractValue(expr, Literal(fieldName), resolver))
- Some(Alias(fieldExprs, nestedFields.last)())
-
- // No matches.
- case Seq() =>
- logTrace(s"Could not find $name in ${input.mkString(", ")}")
- None
-
- // More than one match.
- case ambiguousReferences =>
- val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ")
- throw new AnalysisException(
- s"Reference '$name' is ambiguous, could be: $referenceNames.")
- }
- }
+ def refresh(): Unit = children.foreach(_.refresh())
/**
- * Refreshes (or invalidates) any metadata/data cached in the plan recursively.
+ * Returns the output ordering that this plan generates.
*/
- def refresh(): Unit = children.foreach(_.refresh())
+ def outputOrdering: Seq[SortOrder] = Nil
}
/**
@@ -274,3 +185,7 @@ abstract class BinaryNode extends LogicalPlan {
override final def children: Seq[LogicalPlan] = Seq(left, right)
}
+
+abstract class OrderPreservingUnaryNode extends UnaryNode {
+ override final def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index 5c7b8e5b97883..cc352c59dff80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
-trait QueryPlanConstraints { self: LogicalPlan =>
+trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>
/**
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
@@ -32,7 +32,7 @@ trait QueryPlanConstraints { self: LogicalPlan =>
ExpressionSet(
validConstraints
.union(inferAdditionalConstraints(validConstraints))
- .union(constructIsNotNullConstraints(validConstraints))
+ .union(constructIsNotNullConstraints(validConstraints, output))
.filter { c =>
c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
}
@@ -51,13 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* See [[Canonicalize]] for more details.
*/
protected def validConstraints: Set[Expression] = Set.empty
+}
+
+trait ConstraintHelper {
+
+ /**
+ * Infers an additional set of constraints from a given set of equality constraints.
+ * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
+ * additional constraint of the form `b = 5`.
+ */
+ def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
+ var inferredConstraints = Set.empty[Expression]
+ constraints.foreach {
+ case eq @ EqualTo(l: Attribute, r: Attribute) =>
+ val candidateConstraints = constraints - eq
+ inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
+ inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
+ case _ => // No inference
+ }
+ inferredConstraints -- constraints
+ }
+
+ private def replaceConstraints(
+ constraints: Set[Expression],
+ source: Expression,
+ destination: Attribute): Set[Expression] = constraints.map(_ transform {
+ case e: Expression if e.semanticEquals(source) => destination
+ })
/**
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
* non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
* returns a constraint of the form `isNotNull(a)`
*/
- private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
+ def constructIsNotNullConstraints(
+ constraints: Set[Expression],
+ output: Seq[Attribute]): Set[Expression] = {
// First, we propagate constraints from the null intolerant expressions.
var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints)
@@ -93,28 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan =>
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
}
-
- /**
- * Infers an additional set of constraints from a given set of equality constraints.
- * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
- * additional constraint of the form `b = 5`.
- */
- private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
- var inferredConstraints = Set.empty[Expression]
- constraints.foreach {
- case eq @ EqualTo(l: Attribute, r: Attribute) =>
- val candidateConstraints = constraints - eq
- inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
- inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
- case _ => // No inference
- }
- inferredConstraints -- constraints
- }
-
- private def replaceConstraints(
- constraints: Set[Expression],
- source: Expression,
- destination: Attribute): Set[Expression] = constraints.map(_ transform {
- case e: Expression if e.semanticEquals(source) => destination
- })
}
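
A REPL-style sketch of the transitive inference the extracted ConstraintHelper performs (assumes spark-catalyst on the classpath; the attribute names and the throwaway ConstraintDemo mixin object are made up for illustration):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ConstraintHelper
import org.apache.spark.sql.types.IntegerType

object ConstraintDemo extends ConstraintHelper

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

// From (a = 5, a = b) we expect the additional constraint b = 5: replaceConstraints
// rewrites `a` to `b` (and vice versa) in every constraint other than the equality itself.
val constraints: Set[Expression] = Set(EqualTo(a, Literal(5)), EqualTo(a, b))
val inferred = ConstraintDemo.inferAdditionalConstraints(constraints)
// inferred == Set(EqualTo(b, Literal(5)))
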
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index 96b199d7f20b0..b3a48860aa63b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -27,6 +27,7 @@ import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils}
@@ -79,11 +80,10 @@ case class Statistics(
/**
* Statistics collected for a column.
*
- * 1. Supported data types are defined in `ColumnStat.supportsType`.
- * 2. The JVM data type stored in min/max is the internal data type for the corresponding
+ * 1. The JVM data type stored in min/max is the internal data type for the corresponding
* Catalyst data type. For example, the internal type of DateType is Int, and that the internal
* type of TimestampType is Long.
- * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms
+ * 2. There is no guarantee that the statistics collected are accurate. Approximation algorithms
* (sketches) might have been used, and the data collected can also be stale.
*
* @param distinctCount number of distinct values
@@ -95,240 +95,32 @@ case class Statistics(
* @param histogram histogram of the values
*/
case class ColumnStat(
- distinctCount: BigInt,
- min: Option[Any],
- max: Option[Any],
- nullCount: BigInt,
- avgLen: Long,
- maxLen: Long,
+ distinctCount: Option[BigInt] = None,
+ min: Option[Any] = None,
+ max: Option[Any] = None,
+ nullCount: Option[BigInt] = None,
+ avgLen: Option[Long] = None,
+ maxLen: Option[Long] = None,
histogram: Option[Histogram] = None) {
- // We currently don't store min/max for binary/string type. This can change in the future and
- // then we need to remove this require.
- require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String]))
- require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String]))
-
- /**
- * Returns a map from string to string that can be used to serialize the column stats.
- * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string
- * representation for the value. min/max values are converted to the external data type. For
- * example, for DateType we store java.sql.Date, and for TimestampType we store
- * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]].
- *
- * As part of the protocol, the returned map always contains a key called "version".
- * In the case min/max values are null (None), they won't appear in the map.
- */
- def toMap(colName: String, dataType: DataType): Map[String, String] = {
- val map = new scala.collection.mutable.HashMap[String, String]
- map.put(ColumnStat.KEY_VERSION, "1")
- map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString)
- map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString)
- map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString)
- map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString)
- min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) }
- max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) }
- histogram.foreach { h => map.put(ColumnStat.KEY_HISTOGRAM, HistogramSerializer.serialize(h)) }
- map.toMap
- }
-
- /**
- * Converts the given value from Catalyst data type to string representation of external
- * data type.
- */
- private def toExternalString(v: Any, colName: String, dataType: DataType): String = {
- val externalValue = dataType match {
- case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int])
- case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long])
- case BooleanType | _: IntegralType | FloatType | DoubleType => v
- case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal
- // This version of Spark does not use min/max for binary/string types so we ignore it.
- case _ =>
- throw new AnalysisException("Column statistics deserialization is not supported for " +
- s"column $colName of data type: $dataType.")
- }
- externalValue.toString
- }
-
-}
+ // Are distinctCount and nullCount statistics defined?
+ val hasCountStats = distinctCount.isDefined && nullCount.isDefined
+ // Are min and max statistics defined?
+ val hasMinMaxStats = min.isDefined && max.isDefined
-object ColumnStat extends Logging {
-
- // List of string keys used to serialize ColumnStat
- val KEY_VERSION = "version"
- private val KEY_DISTINCT_COUNT = "distinctCount"
- private val KEY_MIN_VALUE = "min"
- private val KEY_MAX_VALUE = "max"
- private val KEY_NULL_COUNT = "nullCount"
- private val KEY_AVG_LEN = "avgLen"
- private val KEY_MAX_LEN = "maxLen"
- private val KEY_HISTOGRAM = "histogram"
-
- /** Returns true iff the we support gathering column statistics on column of the given type. */
- def supportsType(dataType: DataType): Boolean = dataType match {
- case _: IntegralType => true
- case _: DecimalType => true
- case DoubleType | FloatType => true
- case BooleanType => true
- case DateType => true
- case TimestampType => true
- case BinaryType | StringType => true
- case _ => false
- }
-
- /** Returns true iff the we support gathering histogram on column of the given type. */
- def supportsHistogram(dataType: DataType): Boolean = dataType match {
- case _: IntegralType => true
- case _: DecimalType => true
- case DoubleType | FloatType => true
- case DateType => true
- case TimestampType => true
- case _ => false
- }
-
- /**
- * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats
- * from some external storage. The serialization side is defined in [[ColumnStat.toMap]].
- */
- def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = {
- try {
- Some(ColumnStat(
- distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong),
- // Note that flatMap(Option.apply) turns Option(null) into None.
- min = map.get(KEY_MIN_VALUE)
- .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
- max = map.get(KEY_MAX_VALUE)
- .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
- nullCount = BigInt(map(KEY_NULL_COUNT).toLong),
- avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong,
- maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong,
- histogram = map.get(KEY_HISTOGRAM).map(HistogramSerializer.deserialize)
- ))
- } catch {
- case NonFatal(e) =>
- logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e)
- None
- }
- }
-
- /**
- * Converts from string representation of external data type to the corresponding Catalyst data
- * type.
- */
- private def fromExternalString(s: String, name: String, dataType: DataType): Any = {
- dataType match {
- case BooleanType => s.toBoolean
- case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s))
- case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s))
- case ByteType => s.toByte
- case ShortType => s.toShort
- case IntegerType => s.toInt
- case LongType => s.toLong
- case FloatType => s.toFloat
- case DoubleType => s.toDouble
- case _: DecimalType => Decimal(s)
- // This version of Spark does not use min/max for binary/string types so we ignore it.
- case BinaryType | StringType => null
- case _ =>
- throw new AnalysisException("Column statistics deserialization is not supported for " +
- s"column $name of data type: $dataType.")
- }
- }
-
- /**
- * Constructs an expression to compute column statistics for a given column.
- *
- * The expression should create a single struct column with the following schema:
- * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long,
- * distinctCountsForIntervals: Array[Long]
- *
- * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and
- * as a result should stay in sync with it.
- */
- def statExprs(
- col: Attribute,
- conf: SQLConf,
- colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = {
- def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr =>
- expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() }
- })
- val one = Literal(1, LongType)
-
- // the approximate ndv (num distinct value) should never be larger than the number of rows
- val numNonNulls = if (col.nullable) Count(col) else Count(one)
- val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls))
- val numNulls = Subtract(Count(one), numNonNulls)
- val defaultSize = Literal(col.dataType.defaultSize, LongType)
- val nullArray = Literal(null, ArrayType(LongType))
-
- def fixedLenTypeStruct: CreateNamedStruct = {
- val genHistogram =
- ColumnStat.supportsHistogram(col.dataType) && colPercentiles.contains(col)
- val intervalNdvsExpr = if (genHistogram) {
- ApproxCountDistinctForIntervals(col,
- Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError)
- } else {
- nullArray
- }
- // For fixed width types, avg size should be the same as max size.
- struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls,
- defaultSize, defaultSize, intervalNdvsExpr)
- }
-
- col.dataType match {
- case _: IntegralType => fixedLenTypeStruct
- case _: DecimalType => fixedLenTypeStruct
- case DoubleType | FloatType => fixedLenTypeStruct
- case BooleanType => fixedLenTypeStruct
- case DateType => fixedLenTypeStruct
- case TimestampType => fixedLenTypeStruct
- case BinaryType | StringType =>
- // For string and binary type, we don't compute min, max or histogram
- val nullLit = Literal(null, col.dataType)
- struct(
- ndv, nullLit, nullLit, numNulls,
- // Set avg/max size to default size if all the values are null or there is no value.
- Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)),
- Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)),
- nullArray)
- case _ =>
- throw new AnalysisException("Analyzing column statistics is not supported for column " +
- s"${col.name} of data type: ${col.dataType}.")
- }
- }
-
- /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */
- def rowToColumnStat(
- row: InternalRow,
- attr: Attribute,
- rowCount: Long,
- percentiles: Option[ArrayData]): ColumnStat = {
- // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins.
- val cs = ColumnStat(
- distinctCount = BigInt(row.getLong(0)),
- // for string/binary min/max, get should return null
- min = Option(row.get(1, attr.dataType)),
- max = Option(row.get(2, attr.dataType)),
- nullCount = BigInt(row.getLong(3)),
- avgLen = row.getLong(4),
- maxLen = row.getLong(5)
- )
- if (row.isNullAt(6)) {
- cs
- } else {
- val ndvs = row.getArray(6).toLongArray()
- assert(percentiles.get.numElements() == ndvs.length + 1)
- val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble)
- // Construct equi-height histogram
- val bins = ndvs.zipWithIndex.map { case (ndv, i) =>
- HistogramBin(endpoints(i), endpoints(i + 1), ndv)
- }
- val nonNullRows = rowCount - cs.nullCount
- val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins)
- cs.copy(histogram = Some(histogram))
- }
- }
+ // Are avgLen and maxLen statistics defined?
+ val hasLenStats = avgLen.isDefined && maxLen.isDefined
+ def toCatalogColumnStat(colName: String, dataType: DataType): CatalogColumnStat =
+ CatalogColumnStat(
+ distinctCount = distinctCount,
+ min = min.map(CatalogColumnStat.toExternalString(_, colName, dataType)),
+ max = max.map(CatalogColumnStat.toExternalString(_, colName, dataType)),
+ nullCount = nullCount,
+ avgLen = avgLen,
+ maxLen = maxLen,
+ histogram = histogram)
}
/**
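
A small hedged sketch of how the Option-based ColumnStat is meant to be used after this refactoring (the column name and data type are illustrative; assumes spark-catalyst on the classpath):

import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.types.IntegerType

// Only count statistics were collected; min/max and length stats stay None.
val partial = ColumnStat(distinctCount = Some(BigInt(42)), nullCount = Some(BigInt(0)))
assert(partial.hasCountStats)     // both distinctCount and nullCount are defined
assert(!partial.hasMinMaxStats)   // min/max were never gathered
assert(!partial.hasLenStats)      // avgLen/maxLen were never gathered

// The catalog representation keeps the same optionality; min/max would be converted to
// their external string form if they were present.
val catalogStat = partial.toCatalogColumnStat("age", IntegerType)
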
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index a4fca790dd086..3bf32ef7884e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -43,11 +43,12 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
* This node is inserted at the top of a subquery when it is optimized. This makes sure we can
* recognize a subquery as such, and it allows us to write subquery aware transformations.
*/
-case class Subquery(child: LogicalPlan) extends UnaryNode {
+case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
}
-case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
+case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
+ extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override def maxRows: Option[Long] = child.maxRows
@@ -125,7 +126,7 @@ case class Generate(
}
case class Filter(condition: Expression, child: LogicalPlan)
- extends UnaryNode with PredicateHelper {
+ extends OrderPreservingUnaryNode with PredicateHelper {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = child.maxRows
@@ -469,6 +470,7 @@ case class Sort(
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = child.maxRows
+ override def outputOrdering: Seq[SortOrder] = order
}
/** Factory for constructing new `Range` nodes. */
@@ -522,6 +524,15 @@ case class Range(
override def computeStats(): Statistics = {
Statistics(sizeInBytes = LongType.defaultSize * numElements)
}
+
+ override def outputOrdering: Seq[SortOrder] = {
+ val order = if (step > 0) {
+ Ascending
+ } else {
+ Descending
+ }
+ output.map(a => SortOrder(a, order))
+ }
}
case class Aggregate(
@@ -675,17 +686,34 @@ case class GroupingSets(
override lazy val resolved: Boolean = false
}
+/**
+ * A constructor for creating a pivot, which will later be converted to a [[Project]]
+ * or an [[Aggregate]] during the query analysis.
+ *
+ * @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming
+ * from SQL, in which case group by expressions are not explicitly specified.
+ * @param pivotColumn The pivot column.
+ * @param pivotValues A sequence of values for the pivot column.
+ * @param aggregates The aggregation expressions, each with or without an alias.
+ * @param child Child operator
+ */
case class Pivot(
- groupByExprs: Seq[NamedExpression],
+ groupByExprsOpt: Option[Seq[NamedExpression]],
pivotColumn: Expression,
pivotValues: Seq[Literal],
aggregates: Seq[Expression],
child: LogicalPlan) extends UnaryNode {
- override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
- case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
- case _ => pivotValues.flatMap{ value =>
- aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)())
+ override lazy val resolved = false // Pivot will be replaced after being resolved.
+ override def output: Seq[Attribute] = {
+ val pivotAgg = aggregates match {
+ case agg :: Nil =>
+ pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
+ case _ =>
+ pivotValues.flatMap { value =>
+ aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)())
+ }
}
+ groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg
}
}
@@ -728,7 +756,7 @@ object Limit {
*
* See [[Limit]] for more information.
*/
-case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
+case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = {
limitExpr match {
@@ -744,7 +772,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
*
* See [[Limit]] for more information.
*/
-case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
+case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRowsPerPartition: Option[Long] = {
@@ -764,7 +792,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
case class SubqueryAlias(
alias: String,
child: LogicalPlan)
- extends UnaryNode {
+ extends OrderPreservingUnaryNode {
override def doCanonicalize(): LogicalPlan = child.canonicalized
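
A hedged sketch of the effect of OrderPreservingUnaryNode (assumes spark-catalyst on the classpath; the relation and attribute are made up): a sort order established below a Project/Filter chain is now reported at the top of the chain.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.LongType

val id = AttributeReference("id", LongType)()
val sorted = Sort(Seq(SortOrder(id, Ascending)), global = true, LocalRelation(id))

// Project and Filter now extend OrderPreservingUnaryNode, so outputOrdering passes through.
val plan = Project(Seq(id), Filter(GreaterThan(id, Literal(0L)), sorted))
// plan.outputOrdering should be Seq(SortOrder(id, Ascending)), matching the child Sort,
// which optimizer rules can use to elide a redundant Sort on top of this plan.
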
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index c41fac4015ec0..111c594a53e52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -32,13 +32,15 @@ object AggregateEstimation {
val childStats = agg.child.stats
// Check if we have column stats for all group-by columns.
val colStatsExist = agg.groupingExpressions.forall { e =>
- e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
+ e.isInstanceOf[Attribute] &&
+ childStats.attributeStats.get(e.asInstanceOf[Attribute]).exists(_.hasCountStats)
}
if (rowCountsExist(agg.child) && colStatsExist) {
// Multiply distinct counts of group-by columns. This is an upper bound, which assumes
// the data contains all combinations of distinct values of group-by columns.
var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(
- (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount)
+ (res, expr) => res *
+ childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount.get)
outputRows = if (agg.groupingExpressions.isEmpty) {
// If there's no group-by columns, the output is a single row containing values of aggregate
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index d793f77413d18..211a2a0717371 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{DecimalType, _}
-
object EstimationUtils {
/** Check if each plan has rowCount in its statistics. */
@@ -38,9 +37,18 @@ object EstimationUtils {
}
}
+  /** Check if each attribute has a column stat containing distinct and null counts
+   *  in the corresponding statistics. */
+ def columnStatsWithCountsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = {
+ statsAndAttr.forall { case (stats, attr) =>
+ stats.attributeStats.get(attr).map(_.hasCountStats).getOrElse(false)
+ }
+ }
+
+  /** Statistics for a column containing only NULLs. */
def nullColumnStat(dataType: DataType, rowCount: BigInt): ColumnStat = {
- ColumnStat(distinctCount = 0, min = None, max = None, nullCount = rowCount,
- avgLen = dataType.defaultSize, maxLen = dataType.defaultSize)
+ ColumnStat(distinctCount = Some(0), min = None, max = None, nullCount = Some(rowCount),
+ avgLen = Some(dataType.defaultSize), maxLen = Some(dataType.defaultSize))
}
/**
@@ -63,29 +71,33 @@ object EstimationUtils {
AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
}
- def getOutputSize(
+ def getSizePerRow(
attributes: Seq[Attribute],
- outputRowCount: BigInt,
attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
// We assign a generic overhead for a Row object, the actual overhead is different for different
// Row format.
- val sizePerRow = 8 + attributes.map { attr =>
- if (attrStats.contains(attr)) {
+ 8 + attributes.map { attr =>
+ if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) {
attr.dataType match {
case StringType =>
// UTF8String: base + offset + numBytes
- attrStats(attr).avgLen + 8 + 4
+ attrStats(attr).avgLen.get + 8 + 4
case _ =>
- attrStats(attr).avgLen
+ attrStats(attr).avgLen.get
}
} else {
attr.dataType.defaultSize
}
}.sum
+ }
+ def getOutputSize(
+ attributes: Seq[Attribute],
+ outputRowCount: BigInt,
+ attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
// Output size can't be zero, or sizeInBytes of BinaryNode will also be zero
// (simple computation of statistics returns product of children).
- if (outputRowCount > 0) outputRowCount * sizePerRow else 1
+ if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1
}
/**
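
A hedged back-of-the-envelope use of the new getSizePerRow helper (attribute names and stats are made up; assumes spark-catalyst on the classpath):

import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.types.{IntegerType, StringType}

val name = AttributeReference("name", StringType)()
val age = AttributeReference("age", IntegerType)()
val colStats = AttributeMap(Seq(name -> ColumnStat(avgLen = Some(16L), maxLen = Some(32L))))

// 8 bytes of generic row overhead
//   + (16 + 8 + 4) for the string column (avgLen plus the UTF8String base/offset/numBytes)
//   + 4 for the int column, which has no avgLen and falls back to its default size
// = 40 bytes per row.
val bytesPerRow = EstimationUtils.getSizePerRow(Seq(name, age), colStats)
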
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 4cc32de2d32d7..5a3eeefaedb18 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
attr: Attribute,
isNull: Boolean,
update: Boolean): Option[Double] = {
- if (!colStatsMap.contains(attr)) {
+ if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) {
logDebug("[CBO] No statistics for " + attr)
return None
}
@@ -234,14 +234,14 @@ case class FilterEstimation(plan: Filter) extends Logging {
val nullPercent: Double = if (rowCountValue == 0) {
0
} else {
- (BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)).toDouble
+ (BigDecimal(colStat.nullCount.get) / BigDecimal(rowCountValue)).toDouble
}
if (update) {
val newStats = if (isNull) {
- colStat.copy(distinctCount = 0, min = None, max = None)
+ colStat.copy(distinctCount = Some(0), min = None, max = None)
} else {
- colStat.copy(nullCount = 0)
+ colStat.copy(nullCount = Some(0))
}
colStatsMap.update(attr, newStats)
}
@@ -322,17 +322,21 @@ case class FilterEstimation(plan: Filter) extends Logging {
// value.
val newStats = attr.dataType match {
case StringType | BinaryType =>
- colStat.copy(distinctCount = 1, nullCount = 0)
+ colStat.copy(distinctCount = Some(1), nullCount = Some(0))
case _ =>
- colStat.copy(distinctCount = 1, min = Some(literal.value),
- max = Some(literal.value), nullCount = 0)
+ colStat.copy(distinctCount = Some(1), min = Some(literal.value),
+ max = Some(literal.value), nullCount = Some(0))
}
colStatsMap.update(attr, newStats)
}
if (colStat.histogram.isEmpty) {
- // returns 1/ndv if there is no histogram
- Some(1.0 / colStat.distinctCount.toDouble)
+ if (!colStat.distinctCount.isEmpty) {
+ // returns 1/ndv if there is no histogram
+ Some(1.0 / colStat.distinctCount.get.toDouble)
+ } else {
+ None
+ }
} else {
Some(computeEqualityPossibilityByHistogram(literal, colStat))
}
@@ -378,19 +382,23 @@ case class FilterEstimation(plan: Filter) extends Logging {
attr: Attribute,
hSet: Set[Any],
update: Boolean): Option[Double] = {
- if (!colStatsMap.contains(attr)) {
+ if (!colStatsMap.hasDistinctCount(attr)) {
logDebug("[CBO] No statistics for " + attr)
return None
}
val colStat = colStatsMap(attr)
- val ndv = colStat.distinctCount
+ val ndv = colStat.distinctCount.get
val dataType = attr.dataType
var newNdv = ndv
// use [min, max] to filter the original hSet
dataType match {
case _: NumericType | BooleanType | DateType | TimestampType =>
+ if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) {
+ return Some(0.0)
+ }
+
val statsInterval =
ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval]
val validQuerySet = hSet.filter { v =>
@@ -407,16 +415,20 @@ case class FilterEstimation(plan: Filter) extends Logging {
// 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
newNdv = ndv.min(BigInt(validQuerySet.size))
if (update) {
- val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin),
- max = Some(newMax), nullCount = 0)
+ val newStats = colStat.copy(distinctCount = Some(newNdv), min = Some(newMin),
+ max = Some(newMax), nullCount = Some(0))
colStatsMap.update(attr, newStats)
}
// We assume the whole set since there is no min/max information for String/Binary type
case StringType | BinaryType =>
+ if (ndv.toDouble == 0) {
+ return Some(0.0)
+ }
+
newNdv = ndv.min(BigInt(hSet.size))
if (update) {
- val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0)
+ val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0))
colStatsMap.update(attr, newStats)
}
}
@@ -443,12 +455,17 @@ case class FilterEstimation(plan: Filter) extends Logging {
literal: Literal,
update: Boolean): Option[Double] = {
+ if (!colStatsMap.hasMinMaxStats(attr) || !colStatsMap.hasDistinctCount(attr)) {
+ logDebug("[CBO] No statistics for " + attr)
+ return None
+ }
+
val colStat = colStatsMap(attr)
val statsInterval =
ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval]
val max = statsInterval.max
val min = statsInterval.min
- val ndv = colStat.distinctCount.toDouble
+ val ndv = colStat.distinctCount.get.toDouble
// determine the overlapping degree between predicate interval and column's interval
val numericLiteral = EstimationUtils.toDouble(literal.value, literal.dataType)
@@ -520,8 +537,8 @@ case class FilterEstimation(plan: Filter) extends Logging {
newMax = newValue
}
- val newStats = colStat.copy(distinctCount = ceil(ndv * percent),
- min = newMin, max = newMax, nullCount = 0)
+ val newStats = colStat.copy(distinctCount = Some(ceil(ndv * percent)),
+ min = newMin, max = newMax, nullCount = Some(0))
colStatsMap.update(attr, newStats)
}
@@ -637,11 +654,11 @@ case class FilterEstimation(plan: Filter) extends Logging {
attrRight: Attribute,
update: Boolean): Option[Double] = {
- if (!colStatsMap.contains(attrLeft)) {
+ if (!colStatsMap.hasCountStats(attrLeft)) {
logDebug("[CBO] No statistics for " + attrLeft)
return None
}
- if (!colStatsMap.contains(attrRight)) {
+ if (!colStatsMap.hasCountStats(attrRight)) {
logDebug("[CBO] No statistics for " + attrRight)
return None
}
@@ -668,7 +685,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
val minRight = statsIntervalRight.min
// determine the overlapping degree between predicate interval and column's interval
- val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
+ val allNotNull = (colStatLeft.nullCount.get == 0) && (colStatRight.nullCount.get == 0)
val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
// Left < Right or Left <= Right
// - no overlap:
@@ -707,14 +724,14 @@ case class FilterEstimation(plan: Filter) extends Logging {
case _: EqualTo =>
((maxLeft < minRight) || (maxRight < minLeft),
(minLeft == minRight) && (maxLeft == maxRight) && allNotNull
- && (colStatLeft.distinctCount == colStatRight.distinctCount)
+ && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get)
)
case _: EqualNullSafe =>
// For null-safe equality, we use a very restrictive condition to evaluate its overlap.
// If null values exists, we set it to partial overlap.
(((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull,
(minLeft == minRight) && (maxLeft == maxRight) && allNotNull
- && (colStatLeft.distinctCount == colStatRight.distinctCount)
+ && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get)
)
}
@@ -731,9 +748,9 @@ case class FilterEstimation(plan: Filter) extends Logging {
if (update) {
// Need to adjust new min/max after the filter condition is applied
- val ndvLeft = BigDecimal(colStatLeft.distinctCount)
+ val ndvLeft = BigDecimal(colStatLeft.distinctCount.get)
val newNdvLeft = ceil(ndvLeft * percent)
- val ndvRight = BigDecimal(colStatRight.distinctCount)
+ val ndvRight = BigDecimal(colStatRight.distinctCount.get)
val newNdvRight = ceil(ndvRight * percent)
var newMaxLeft = colStatLeft.max
@@ -817,10 +834,10 @@ case class FilterEstimation(plan: Filter) extends Logging {
}
}
- val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft,
+ val newStatsLeft = colStatLeft.copy(distinctCount = Some(newNdvLeft), min = newMinLeft,
max = newMaxLeft)
colStatsMap(attrLeft) = newStatsLeft
- val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight,
+ val newStatsRight = colStatRight.copy(distinctCount = Some(newNdvRight), min = newMinRight,
max = newMaxRight)
colStatsMap(attrRight) = newStatsRight
}
@@ -849,17 +866,35 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) {
def contains(a: Attribute): Boolean = updatedMap.contains(a.exprId) || originalMap.contains(a)
/**
- * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in
- * originalMap, because updatedMap has the latest (updated) column stats.
+ * Gets an Option of column stat for the given attribute.
+ * Prefer the column stat in updatedMap over the one in originalMap,
+ * because updatedMap has the latest (updated) column stats.
*/
- def apply(a: Attribute): ColumnStat = {
+ def get(a: Attribute): Option[ColumnStat] = {
if (updatedMap.contains(a.exprId)) {
- updatedMap(a.exprId)._2
+ updatedMap.get(a.exprId).map(_._2)
} else {
- originalMap(a)
+ originalMap.get(a)
}
}
+ def hasCountStats(a: Attribute): Boolean =
+ get(a).map(_.hasCountStats).getOrElse(false)
+
+ def hasDistinctCount(a: Attribute): Boolean =
+ get(a).map(_.distinctCount.isDefined).getOrElse(false)
+
+ def hasMinMaxStats(a: Attribute): Boolean =
+    get(a).map(_.hasMinMaxStats).getOrElse(false)
+
+ /**
+   * Gets the column stat for the given attribute. Prefer the column stat in updatedMap over the
+   * one in originalMap, because updatedMap has the latest (updated) column stats.
+ */
+ def apply(a: Attribute): ColumnStat = {
+ get(a).get
+ }
+
/** Updates column stats in updatedMap. */
def update(a: Attribute, stats: ColumnStat): Unit = updatedMap.update(a.exprId, a -> stats)
@@ -871,11 +906,14 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) {
: AttributeMap[ColumnStat] = {
val newColumnStats = originalMap.map { case (attr, oriColStat) =>
val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat)
- val newNdv = if (colStat.distinctCount > 1) {
+ val newNdv = if (colStat.distinctCount.isEmpty) {
+ // No NDV in the original stats.
+ None
+ } else if (colStat.distinctCount.get > 1) {
// Update ndv based on the overall filter selectivity: scale down ndv if the number of rows
// decreases; otherwise keep it unchanged.
- EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter,
- newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount)
+ Some(EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter,
+ newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount.get))
} else {
// no need to scale down since it is already down to 1 (for skewed distribution case)
colStat.distinctCount
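
The guard pattern used throughout this file can be summarized with a small hedged sketch (the helper below is illustrative, not part of Spark): each estimation step first checks which pieces of the Option-based ColumnStat exist and returns None, i.e. "unknown selectivity", when they do not.

import org.apache.spark.sql.catalyst.plans.logical.ColumnStat

// Illustrative only: mirrors the 1/ndv branch shown above when no histogram is available.
def equalitySelectivity(colStat: ColumnStat): Option[Double] =
  colStat.distinctCount.filter(_ > 0).map(ndv => 1.0 / ndv.toDouble)

equalitySelectivity(ColumnStat(distinctCount = Some(BigInt(100))))   // Some(0.01)
equalitySelectivity(ColumnStat())                                    // None: no NDV available
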
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index f0294a4246703..2543e38a92c0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -85,7 +85,8 @@ case class JoinEstimation(join: Join) extends Logging {
// 3. Update statistics based on the output of join
val inputAttrStats = AttributeMap(
leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq)
- val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a))
+ val attributesWithStat = join.output.filter(a =>
+ inputAttrStats.get(a).map(_.hasCountStats).getOrElse(false))
val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_))
val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) {
@@ -106,10 +107,10 @@ case class JoinEstimation(join: Join) extends Logging {
case FullOuter =>
fromLeft.map { a =>
val oriColStat = inputAttrStats(a)
- (a, oriColStat.copy(nullCount = oriColStat.nullCount + rightRows))
+ (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + rightRows)))
} ++ fromRight.map { a =>
val oriColStat = inputAttrStats(a)
- (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows))
+ (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + leftRows)))
}
case _ =>
assert(joinType == Inner || joinType == Cross)
@@ -219,19 +220,27 @@ case class JoinEstimation(join: Join) extends Logging {
private def computeByNdv(
leftKey: AttributeReference,
rightKey: AttributeReference,
- newMin: Option[Any],
- newMax: Option[Any]): (BigInt, ColumnStat) = {
+ min: Option[Any],
+ max: Option[Any]): (BigInt, ColumnStat) = {
val leftKeyStat = leftStats.attributeStats(leftKey)
val rightKeyStat = rightStats.attributeStats(rightKey)
- val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount)
+ val maxNdv = leftKeyStat.distinctCount.get.max(rightKeyStat.distinctCount.get)
// Compute cardinality by the basic formula.
val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv)
// Get the intersected column stat.
- val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount)
- val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen)
- val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2
- val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
+ val newNdv = Some(leftKeyStat.distinctCount.get.min(rightKeyStat.distinctCount.get))
+ val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) {
+ Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get))
+ } else {
+ None
+ }
+ val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) {
+ Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2)
+ } else {
+ None
+ }
+ val newStats = ColumnStat(newNdv, min, max, Some(0), newAvgLen, newMaxLen)
(ceil(card), newStats)
}
@@ -267,9 +276,17 @@ case class JoinEstimation(join: Join) extends Logging {
val leftKeyStat = leftStats.attributeStats(leftKey)
val rightKeyStat = rightStats.attributeStats(rightKey)
- val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen)
- val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2
- val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen)
+ val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) {
+ Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get))
+ } else {
+ None
+ }
+ val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) {
+ Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2)
+ } else {
+ None
+ }
+ val newStats = ColumnStat(Some(ceil(totalNdv)), newMin, newMax, Some(0), newAvgLen, newMaxLen)
(ceil(card), newStats)
}
@@ -292,10 +309,14 @@ case class JoinEstimation(join: Join) extends Logging {
} else {
val oldColStat = oldAttrStats(a)
val oldNdv = oldColStat.distinctCount
- val newNdv = if (join.left.outputSet.contains(a)) {
- updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv)
+ val newNdv = if (oldNdv.isDefined) {
+ Some(if (join.left.outputSet.contains(a)) {
+ updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv.get)
+ } else {
+ updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv.get)
+ })
} else {
- updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv)
+ None
}
val newColStat = oldColStat.copy(distinctCount = newNdv)
// TODO: support nullCount updates for specific outer joins
@@ -313,7 +334,7 @@ case class JoinEstimation(join: Join) extends Logging {
// Note: join keys from EqualNullSafe also fall into this case (Coalesce), consider to
// support it in the future by using `nullCount` in column stats.
case (lk: AttributeReference, rk: AttributeReference)
- if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk)
+ if columnStatsWithCountsExist((leftStats, lk), (rightStats, rk)) => (lk, rk)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
index 85f67c7d66075..ee43f9126386b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
@@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
private def visitUnaryNode(p: UnaryNode): Statistics = {
// There should be some overhead in Row object, the size should not be zero when there is
// no columns, this help to prevent divide-by-zero error.
- val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8
- val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8
+ val childRowSize = EstimationUtils.getSizePerRow(p.child.output)
+ val outputRowSize = EstimationUtils.getSizePerRow(p.output)
// Assume there will be the same number of rows as child has.
var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize
if (sizeInBytes == 0) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 9c7d47f99ee10..becfa8d982213 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
var changed = false
val remainingNewChildren = newChildren.toBuffer
val remainingOldChildren = children.toBuffer
+ def mapTreeNode(node: TreeNode[_]): TreeNode[_] = {
+ val newChild = remainingNewChildren.remove(0)
+ val oldChild = remainingOldChildren.remove(0)
+ if (newChild fastEquals oldChild) {
+ oldChild
+ } else {
+ changed = true
+ newChild
+ }
+ }
+ def mapChild(child: Any): Any = child match {
+ case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
+ case nonChild: AnyRef => nonChild
+ case null => null
+ }
val newArgs = mapProductIterator {
case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
// Handle Seq[TreeNode] in TreeNode parameters.
- case s: Seq[_] => s.map {
- case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = remainingNewChildren.remove(0)
- val oldChild = remainingOldChildren.remove(0)
- if (newChild fastEquals oldChild) {
- oldChild
- } else {
- changed = true
- newChild
- }
- case nonChild: AnyRef => nonChild
- case null => null
- }
- case m: Map[_, _] => m.mapValues {
- case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = remainingNewChildren.remove(0)
- val oldChild = remainingOldChildren.remove(0)
- if (newChild fastEquals oldChild) {
- oldChild
- } else {
- changed = true
- newChild
- }
- case nonChild: AnyRef => nonChild
- case null => null
- }.view.force // `mapValues` is lazy and we need to force it to materialize
- case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = remainingNewChildren.remove(0)
- val oldChild = remainingOldChildren.remove(0)
- if (newChild fastEquals oldChild) {
- oldChild
- } else {
- changed = true
- newChild
- }
+ case s: Stream[_] =>
+ // Stream is lazy so we need to force materialization
+ s.map(mapChild).force
+ case s: Seq[_] =>
+ s.map(mapChild)
+ case m: Map[_, _] =>
+ // `mapValues` is lazy and we need to force it to materialize
+ m.mapValues(mapChild).view.force
+ case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
case nonChild: AnyRef => nonChild
case null => null
}
@@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
def mapChildren(f: BaseType => BaseType): BaseType = {
if (children.nonEmpty) {
var changed = false
+ def mapChild(child: Any): Any = child match {
+ case arg: TreeNode[_] if containsChild(arg) =>
+ val newChild = f(arg.asInstanceOf[BaseType])
+ if (!(newChild fastEquals arg)) {
+ changed = true
+ newChild
+ } else {
+ arg
+ }
+ case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
+ val newChild1 = if (containsChild(arg1)) {
+ f(arg1.asInstanceOf[BaseType])
+ } else {
+ arg1.asInstanceOf[BaseType]
+ }
+
+ val newChild2 = if (containsChild(arg2)) {
+ f(arg2.asInstanceOf[BaseType])
+ } else {
+ arg2.asInstanceOf[BaseType]
+ }
+
+ if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
+ changed = true
+ (newChild1, newChild2)
+ } else {
+ tuple
+ }
+ case other => other
+ }
+
val newArgs = mapProductIterator {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = f(arg.asInstanceOf[BaseType])
@@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case other => other
}.view.force // `mapValues` is lazy and we need to force it to materialize
case d: DataType => d // Avoid unpacking Structs
- case args: Traversable[_] => args.map {
- case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = f(arg.asInstanceOf[BaseType])
- if (!(newChild fastEquals arg)) {
- changed = true
- newChild
- } else {
- arg
- }
- case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
- val newChild1 = if (containsChild(arg1)) {
- f(arg1.asInstanceOf[BaseType])
- } else {
- arg1.asInstanceOf[BaseType]
- }
-
- val newChild2 = if (containsChild(arg2)) {
- f(arg2.asInstanceOf[BaseType])
- } else {
- arg2.asInstanceOf[BaseType]
- }
-
- if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
- changed = true
- (newChild1, newChild2)
- } else {
- tuple
- }
- case other => other
- }
+ case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
+ case args: Traversable[_] => args.map(mapChild)
case nonChild: AnyRef => nonChild
case null => null
}
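
A hedged illustration (plain Scala, names made up) of the Stream pitfall the new cases guard against: Stream.map only evaluates the head eagerly, so the per-element side effects in withNewChildren/mapChildren (consuming the remaining-children buffers, setting `changed`) would not run for the tail until the result is forced.

import scala.collection.mutable.ArrayBuffer

val remaining = ArrayBuffer("x", "y", "z")
val lazyMapped = Stream(1, 2, 3).map(_ => remaining.remove(0))

// Only the head has been computed, so two elements are still pending in the buffer.
println(remaining)              // likely ArrayBuffer(y, z)
val forced = lazyMapped.force   // evaluates the tail, draining the buffer as intended
println(remaining)              // ArrayBuffer()
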
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index 9beef41d639f3..104b428614849 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.util
import scala.reflect.ClassTag
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
object ArrayData {
def toArrayData(input: Any): ArrayData = input match {
@@ -42,6 +43,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
def array: Array[Any]
+ def toSeq[T](dataType: DataType): IndexedSeq[T] =
+ new ArrayDataIndexedSeq[T](this, dataType)
+
def setNullAt(i: Int): Unit
def update(i: Int, value: Any): Unit
@@ -137,30 +141,55 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
def toArray[T: ClassTag](elementType: DataType): Array[T] = {
val size = numElements()
+ val accessor = InternalRow.getAccessor(elementType)
val values = new Array[T](size)
var i = 0
while (i < size) {
if (isNullAt(i)) {
values(i) = null.asInstanceOf[T]
} else {
- values(i) = get(i, elementType).asInstanceOf[T]
+ values(i) = accessor(this, i).asInstanceOf[T]
}
i += 1
}
values
}
- // todo: specialize this.
def foreach(elementType: DataType, f: (Int, Any) => Unit): Unit = {
val size = numElements()
+ val accessor = InternalRow.getAccessor(elementType)
var i = 0
while (i < size) {
if (isNullAt(i)) {
f(i, null)
} else {
- f(i, get(i, elementType))
+ f(i, accessor(this, i))
}
i += 1
}
}
}
+
+/**
+ * Implements an `IndexedSeq` interface for `ArrayData`. Notice that if the original `ArrayData`
+ * is a primitive array and contains null elements, it is better to ask for `IndexedSeq[Any]`,
+ * instead of `IndexedSeq[Int]`, in order to keep the null elements.
+ */
+class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] {
+
+ private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType)
+
+ override def apply(idx: Int): T =
+ if (0 <= idx && idx < arrayData.numElements()) {
+ if (arrayData.isNullAt(idx)) {
+ null.asInstanceOf[T]
+ } else {
+ accessor(arrayData, idx).asInstanceOf[T]
+ }
+ } else {
+ throw new IndexOutOfBoundsException(
+ s"Index $idx must be between 0 and the length of the ArrayData.")
+ }
+
+ override def length: Int = arrayData.numElements()
+}
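
A hedged usage sketch for the new toSeq/ArrayDataIndexedSeq path (assumes spark-catalyst on the classpath):

import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.IntegerType

val data = ArrayData.toArrayData(Array(1, 2, 3))

// The wrapper exposes ArrayData through the standard Scala collection API without copying.
val view: IndexedSeq[Int] = data.toSeq[Int](IntegerType)
assert(view.sum == 6)
// For arrays of a primitive type that may contain nulls, ask for IndexedSeq[Any] instead,
// as the class comment above notes, so that the null elements are preserved.
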
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index fa69b8af62c85..80f15053005ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -296,10 +296,28 @@ object DateTimeUtils {
* `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m`
*/
def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = {
- stringToTimestamp(s, defaultTimeZone())
+ stringToTimestamp(s, defaultTimeZone(), rejectTzInString = false)
}
def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = {
+ stringToTimestamp(s, timeZone, rejectTzInString = false)
+ }
+
+ /**
+ * Converts a timestamp string to microseconds from the unix epoch, w.r.t. the given timezone.
+ * Returns None if the input string is not a valid timestamp format.
+ *
+ * @param s the input timestamp string.
+   * @param timeZone the timezone of the timestamp string, which will be ignored if the
+   *                 timestamp string already contains timezone information.
+ * @param rejectTzInString if true, rejects timezone in the input string, i.e., if the
+ * timestamp string contains timezone, like `2000-10-10 00:00:00+00:00`,
+ * return None.
+ */
+ def stringToTimestamp(
+ s: UTF8String,
+ timeZone: TimeZone,
+ rejectTzInString: Boolean): Option[SQLTimestamp] = {
if (s == null) {
return None
}
@@ -417,6 +435,8 @@ object DateTimeUtils {
return None
}
+ if (tz.isDefined && rejectTzInString) return None
+
val c = if (tz.isEmpty) {
Calendar.getInstance(timeZone)
} else {
@@ -865,29 +885,19 @@ object DateTimeUtils {
/**
* Returns number of months between time1 and time2. time1 and time2 are expressed in
- * microseconds since 1.1.1970.
- *
- * If time1 and time2 having the same day of month, or both are the last day of month,
- * it returns an integer (time under a day will be ignored).
- *
- * Otherwise, the difference is calculated based on 31 days per month, and rounding to
- * 8 digits.
- */
- def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = {
- monthsBetween(time1, time2, defaultTimeZone())
- }
-
- /**
- * Returns number of months between time1 and time2. time1 and time2 are expressed in
- * microseconds since 1.1.1970.
+ * microseconds since 1.1.1970. If time1 is later than time2, the result is positive.
*
- * If time1 and time2 having the same day of month, or both are the last day of month,
- * it returns an integer (time under a day will be ignored).
+ * If time1 and time2 are on the same day of month, or both are the last day of month,
+ * it returns an integer (time of day will be ignored).
*
- * Otherwise, the difference is calculated based on 31 days per month, and rounding to
- * 8 digits.
+ * Otherwise, the difference is calculated based on 31 days per month.
+ * The result is rounded to 8 decimal places if `roundOff` is set to true.
*/
- def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp, timeZone: TimeZone): Double = {
+ def monthsBetween(
+ time1: SQLTimestamp,
+ time2: SQLTimestamp,
+ roundOff: Boolean,
+ timeZone: TimeZone): Double = {
val millis1 = time1 / 1000L
val millis2 = time2 / 1000L
val date1 = millisToDays(millis1, timeZone)
@@ -898,16 +908,25 @@ object DateTimeUtils {
val months1 = year1 * 12 + monthInYear1
val months2 = year2 * 12 + monthInYear2
+ val monthDiff = (months1 - months2).toDouble
+
if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) {
- return (months1 - months2).toDouble
+ return monthDiff
+ }
+ // using milliseconds can cause precision loss with more than 8 digits
+ // we follow Hive's implementation which uses seconds
+ val secondsInDay1 = (millis1 - daysToMillis(date1, timeZone)) / 1000L
+ val secondsInDay2 = (millis2 - daysToMillis(date2, timeZone)) / 1000L
+ val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2
+ // 2678400D is the number of seconds in 31 days
+ // every month is considered to be 31 days long in this function
+ val diff = monthDiff + secondsDiff / 2678400D
+ if (roundOff) {
+ // rounding to 8 digits
+ math.round(diff * 1e8) / 1e8
+ } else {
+ diff
}
- // milliseconds is enough for 8 digits precision on the right side
- val timeInDay1 = millis1 - daysToMillis(date1, timeZone)
- val timeInDay2 = millis2 - daysToMillis(date2, timeZone)
- val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY
- val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0
- // rounding to 8 digits
- math.round(diff * 1e8) / 1e8
}
// Thursday = 0 since 1970/Jan/01 => Thursday
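
A hedged worked example of the revised monthsBetween formula (dates are made up): for time1 = 2018-03-30 00:00:00 and time2 = 2018-02-20 12:00:00 in the same timezone, monthDiff = 1, dayInMonth1 = 30, dayInMonth2 = 20, secondsInDay1 = 0 and secondsInDay2 = 43200, so:

val secondsDiff = (30 - 20) * 86400L + 0L - 43200L      // 820800 seconds
val diff = 1.0 + secondsDiff / 2678400D                 // 1 + 820800/2678400 ≈ 1.30645161...
val rounded = math.round(diff * 1e8) / 1e8              // 1.30645161 when roundOff is true
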
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
index b013add9c9778..3190e511e2cb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
@@ -40,12 +40,14 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats
* See the G-K article for more details.
* @param count the count of all the elements *inserted in the sampled buffer*
* (excluding the head buffer)
+ * @param compressed whether the statistics have been compressed
*/
class QuantileSummaries(
val compressThreshold: Int,
val relativeError: Double,
val sampled: Array[Stats] = Array.empty,
- val count: Long = 0L) extends Serializable {
+ val count: Long = 0L,
+ var compressed: Boolean = false) extends Serializable {
// a buffer of latest samples seen so far
private val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty
@@ -60,6 +62,7 @@ class QuantileSummaries(
*/
def insert(x: Double): QuantileSummaries = {
headSampled += x
+ compressed = false
if (headSampled.size >= defaultHeadSize) {
val result = this.withHeadBufferInserted
if (result.sampled.length >= compressThreshold) {
@@ -135,11 +138,11 @@ class QuantileSummaries(
assert(inserted.count == count + headSampled.size)
val compressed =
compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count)
- new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count)
+ new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count, true)
}
private def shallowCopy: QuantileSummaries = {
- new QuantileSummaries(compressThreshold, relativeError, sampled, count)
+ new QuantileSummaries(compressThreshold, relativeError, sampled, count, compressed)
}
/**
@@ -163,7 +166,7 @@ class QuantileSummaries(
val res = (sampled ++ other.sampled).sortBy(_.value)
val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count)
new QuantileSummaries(
- other.compressThreshold, other.relativeError, comp, other.count + count)
+ other.compressThreshold, other.relativeError, comp, other.count + count, true)
}
}
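
A hedged sketch of the new `compressed` flag's lifecycle (threshold and values are arbitrary): compress() and merge return summaries marked as compressed, while insert() clears the flag again, which lets callers skip a redundant compress() pass (the motivation suggested by the flag, though not spelled out in this diff).

import org.apache.spark.sql.catalyst.util.QuantileSummaries

var qs = new QuantileSummaries(compressThreshold = 10000, relativeError = 0.01)
(1 to 100).foreach(v => qs = qs.insert(v.toDouble))
assert(!qs.compressed)   // insert() resets the flag

qs = qs.compress()
assert(qs.compressed)    // compress() returns a summary marked as compressed

qs = qs.insert(101.0)
assert(!qs.compressed)   // new raw data arrived, so it is no longer compressed
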
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala
new file mode 100644
index 0000000000000..4fe07a071c1ca
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import java.util.UUID
+
+import org.apache.commons.math3.random.MersenneTwister
+
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class is used to generate a UUID from Pseudo-Random Numbers.
+ *
+ * For the algorithm, see RFC 4122: A Universally Unique IDentifier (UUID) URN Namespace,
+ * section 4.4 "Algorithms for Creating a UUID from Truly Random or Pseudo-Random Numbers".
+ */
+case class RandomUUIDGenerator(randomSeed: Long) {
+ private val random = new MersenneTwister(randomSeed)
+
+ def getNextUUID(): UUID = {
+ val mostSigBits = (random.nextLong() & 0xFFFFFFFFFFFF0FFFL) | 0x0000000000004000L
+ val leastSigBits = (random.nextLong() | 0x8000000000000000L) & 0xBFFFFFFFFFFFFFFFL
+
+ new UUID(mostSigBits, leastSigBits)
+ }
+
+ def getNextUUIDUTF8String(): UTF8String = UTF8String.fromString(getNextUUID().toString())
+}
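
A hedged usage sketch of RandomUUIDGenerator: the bit masks above force the RFC 4122 version (4) and IETF variant bits, and seeding makes the sequence reproducible.

import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator

val gen = RandomUUIDGenerator(42L)
val uuid = gen.getNextUUID()
assert(uuid.version() == 4)   // version nibble forced to 0100 by the mostSigBits mask
assert(uuid.variant() == 2)   // top bits of leastSigBits forced to 10 (IETF variant)

// The same seed yields the same sequence, which is presumably what callers rely on for
// per-seed determinism (an assumption; how callers use it is not shown in this diff).
assert(RandomUUIDGenerator(42L).getNextUUID() == uuid)
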
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala
new file mode 100644
index 0000000000000..19f67236c8979
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import java.util.{Map => JMap}
+
+import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader}
+
+/**
+ * A read-only SQLConf that is created by tasks running on the executor side. It reads the
+ * configs from the local properties, which are propagated from the driver to the executors.
+ */
+class ReadOnlySQLConf(context: TaskContext) extends SQLConf {
+
+ @transient override val settings: JMap[String, String] = {
+ context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]]
+ }
+
+ @transient override protected val reader: ConfigReader = {
+ new ConfigReader(new TaskContextConfigProvider(context))
+ }
+
+ override protected def setConfWithCheck(key: String, value: String): Unit = {
+ throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+ }
+
+ override def unsetConf(key: String): Unit = {
+ throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+ }
+
+ override def unsetConf(entry: ConfigEntry[_]): Unit = {
+ throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+ }
+
+ override def clear(): Unit = {
+ throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+ }
+
+ override def clone(): SQLConf = {
+ throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
+ }
+
+ override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = {
+ throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
+ }
+}
+
+class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider {
+ override def get(key: String): Option[String] = Option(context.getLocalProperty(key))
+}
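
A hedged sketch (illustrative only; these APIs live in Spark-internal packages) of the intended executor-side behavior: inside a task, SQLConf.get is backed by the task's local properties and rejects any mutation.

    // Inside a task body running on an executor:
    val conf = org.apache.spark.sql.internal.SQLConf.get  // a ReadOnlySQLConf when a TaskContext is set
    val pushdown = conf.parquetFilterPushDown              // reads the value propagated from the driver
    // conf.setConfString("spark.sql.shuffle.partitions", "10")
    //   would throw UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")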
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 08a472f5ec8fb..671d19483d524 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -27,11 +27,12 @@ import scala.util.matching.Regex
import org.apache.hadoop.fs.Path
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.util.Utils
@@ -95,7 +96,9 @@ object SQLConf {
/**
* Returns the active config object within the current scope. If there is an active SparkSession,
- * the proper SQLConf associated with the thread's session is used.
+ * the proper SQLConf associated with the thread's active session is used. If it's called from
+ * tasks on the executor side, a SQLConf will be created from job local properties, which are set
+ * and propagated from the driver side.
*
* The way this works is a little bit convoluted, due to the fact that config was added initially
* only for physical plans (and as a result not in sql/catalyst module).
@@ -107,7 +110,22 @@ object SQLConf {
* run tests in parallel. At the time this feature was implemented, this was a no-op since we
* run unit tests (that does not involve SparkSession) in serial order.
*/
- def get: SQLConf = confGetter.get()()
+ def get: SQLConf = {
+ if (TaskContext.get != null) {
+ new ReadOnlySQLConf(TaskContext.get())
+ } else {
+ if (Utils.isTesting && SparkContext.getActive.isDefined) {
+        // The DAGScheduler event loop thread does not have an active SparkSession, so the
+        // `confGetter` would return `fallbackConf`, which is unexpected. Here we prevent that.
+ val schedulerEventLoopThread =
+ SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread
+ if (schedulerEventLoopThread.getId == Thread.currentThread().getId) {
+ throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.")
+ }
+ }
+ confGetter.get()()
+ }
+ }
val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
.internal()
@@ -345,7 +363,7 @@ object SQLConf {
"snappy, gzip, lzo.")
.stringConf
.transform(_.toLowerCase(Locale.ROOT))
- .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo"))
+ .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo", "lz4", "brotli", "zstd"))
.createWithDefault("snappy")
val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown")
@@ -353,6 +371,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val PARQUET_FILTER_PUSHDOWN_DATE_ENABLED = buildConf("spark.sql.parquet.filterPushdown.date")
+ .doc("If true, enables Parquet filter push-down optimization for Date. " +
+ "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.")
+ .internal()
+ .booleanConf
+ .createWithDefault(true)
+
val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat")
.doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " +
"versions, when converting Parquet schema to Spark SQL schema and vice versa.")
@@ -370,7 +395,7 @@ object SQLConf {
.doc("The output committer class used by Parquet. The specified class needs to be a " +
"subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " +
"of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata summaries" +
- "will never be created, irrespective of the value of parquet.enable.summary-metadata")
+ "will never be created, irrespective of the value of parquet.summary.metadata.level")
.internal()
.stringConf
.createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter")
@@ -399,7 +424,7 @@ object SQLConf {
val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl")
.doc("When native, use the native version of ORC support instead of the ORC library in Hive " +
- "1.2.1. It is 'hive' by default prior to Spark 2.3.")
+ "1.2.1. It is 'hive' by default prior to Spark 2.4.")
.internal()
.stringConf
.checkValues(Set("hive", "native"))
@@ -430,7 +455,8 @@ object SQLConf {
val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath")
.doc("When true, check all the partition paths under the table\'s root directory " +
- "when reading data stored in HDFS.")
+ "when reading data stored in HDFS. This configuration will be deprecated in the future " +
+ "releases and replaced by spark.files.ignoreMissingFiles.")
.booleanConf
.createWithDefault(false)
@@ -479,6 +505,16 @@ object SQLConf {
.checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString))
.createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString)
+ val TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP =
+ buildConf("spark.sql.typeCoercion.compareDateTimestampInTimestamp")
+ .internal()
+ .doc("When true (default), compare Date with Timestamp after converting both sides to " +
+ "Timestamp. This behavior is compatible with Hive 2.2 or later. See HIVE-15236. " +
+ "When false, restore the behavior prior to Spark 2.4. Compare Date with Timestamp after " +
+ "converting both sides to string. This config will be removed in spark 3.0")
+ .booleanConf
+ .createWithDefault(true)
+
val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly")
.doc("When true, enable the metadata-only query optimization that use the table's metadata " +
"to produce the partition columns instead of table scans. It applies when all the columns " +
@@ -493,6 +529,14 @@ object SQLConf {
.stringConf
.createWithDefault("_corrupt_record")
+ val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema")
+ .internal()
+ .doc("When true, force the output schema of the from_json() function to be nullable " +
+ "(including all the fields). Otherwise, the schema might not be compatible with" +
+ "actual data, which leads to curruptions.")
+ .booleanConf
+ .createWithDefault(true)
+
val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout")
.doc("Timeout in seconds for the broadcast wait time in broadcast joins.")
.timeConf(TimeUnit.SECONDS)
@@ -660,6 +704,17 @@ object SQLConf {
.intConf
.createWithDefault(100)
+ val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode")
+ .doc("This config determines the fallback behavior of several codegen generators " +
+ "during tests. `FALLBACK` means trying codegen first and then fallbacking to " +
+ "interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " +
+ "`NO_CODEGEN` skips codegen and goes interpreted path always. Note that " +
+ "this config works only for tests.")
+ .internal()
+ .stringConf
+ .checkValues(CodegenObjectFactoryMode.values.map(_.toString))
+ .createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString)
+
val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback")
.internal()
.doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" +
@@ -893,6 +948,14 @@ object SQLConf {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefault(10000L)
+ val STREAMING_NO_DATA_MICRO_BATCHES_ENABLED =
+ buildConf("spark.sql.streaming.noDataMicroBatchesEnabled")
+ .doc(
+ "Whether streaming micro-batch engine will execute batches without data " +
+ "for eager state management for stateful streaming queries.")
+ .booleanConf
+ .createWithDefault(true)
+
val STREAMING_METRICS_ENABLED =
buildConf("spark.sql.streaming.metricsEnabled")
.doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.")
@@ -905,6 +968,13 @@ object SQLConf {
.intConf
.createWithDefault(100)
+ val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS =
+ buildConf("spark.sql.streaming.checkpointFileManagerClass")
+ .doc("The class used to write checkpoint files atomically. This class must be a subclass " +
+ "of the interface CheckpointFileManager.")
+ .internal()
+ .stringConf
+
val NDV_MAX_ERROR =
buildConf("spark.sql.statistics.ndv.maxError")
.internal()
@@ -1058,16 +1128,23 @@ object SQLConf {
.intConf
.createWithDefault(100)
- val ARROW_EXECUTION_ENABLE =
+ val ARROW_EXECUTION_ENABLED =
buildConf("spark.sql.execution.arrow.enabled")
.doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " +
"for use with pyspark.sql.DataFrame.toPandas, and " +
"pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " +
"The following data types are unsupported: " +
- "MapType, ArrayType of TimestampType, and nested StructType.")
+ "BinaryType, MapType, ArrayType of TimestampType, and nested StructType.")
.booleanConf
.createWithDefault(false)
+ val ARROW_FALLBACK_ENABLED =
+ buildConf("spark.sql.execution.arrow.fallback.enabled")
+ .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " +
+ "fallback automatically to non-optimized implementations if an error occurs.")
+ .booleanConf
+ .createWithDefault(true)
+
val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH =
buildConf("spark.sql.execution.arrow.maxRecordsPerBatch")
.doc("When using Apache Arrow, limit the maximum number of records that can be written " +
@@ -1117,8 +1194,17 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val SQL_OPTIONS_REDACTION_PATTERN =
+ buildConf("spark.sql.redaction.options.regex")
+ .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " +
+ "information. The values of options whose names that match this regex will be redacted " +
+ "in the explain output. This redaction is applied on top of the global redaction " +
+ s"configuration defined by ${SECRET_REDACTION_PATTERN.key}.")
+ .regexConf
+ .createWithDefault("(?i)url".r)
+
val SQL_STRING_REDACTION_PATTERN =
- ConfigBuilder("spark.sql.redaction.string.regex")
+ buildConf("spark.sql.redaction.string.regex")
.doc("Regex to decide which parts of strings produced by Spark contain sensitive " +
"information. When this regex matches a string part, that string part is replaced by a " +
"dummy value. This is currently used to redact the output of SQL explain commands. " +
@@ -1137,6 +1223,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION =
+ buildConf("spark.sql.allowCreatingManagedTableUsingNonemptyLocation")
+ .internal()
+ .doc("When this option is set to true, creating managed tables with nonempty location " +
+ "is allowed. Otherwise, an analysis exception is thrown. ")
+ .booleanConf
+ .createWithDefault(false)
+
val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.executorQueueSize")
.internal()
@@ -1156,10 +1250,27 @@ object SQLConf {
val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers")
.internal()
.doc("A comma-separated list of fully qualified data source register class names for which" +
- " StreamWriteSupport is disabled. Writes to these sources will fail back to the V1 Sink.")
+ " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.")
.stringConf
.createWithDefault("")
+ val DISABLED_V2_STREAMING_MICROBATCH_READERS =
+ buildConf("spark.sql.streaming.disabledV2MicroBatchReaders")
+ .internal()
+ .doc(
+ "A comma-separated list of fully qualified data source register class names for which " +
+ "MicroBatchReadSupport is disabled. Reads from these sources will fall back to the " +
+ "V1 Sources.")
+ .stringConf
+ .createWithDefault("")
+
+ val REJECT_TIMEZONE_IN_STRING = buildConf("spark.sql.function.rejectTimezoneInString")
+ .internal()
+ .doc("If true, `to_utc_timestamp` and `from_utc_timestamp` return null if the input string " +
+ "contains a timezone part, e.g. `2000-10-10 00:00:00+00:00`.")
+ .booleanConf
+ .createWithDefault(true)
+
object PartitionOverwriteMode extends Enumeration {
val STATIC, DYNAMIC = Value
}
@@ -1190,6 +1301,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val TOP_K_SORT_FALLBACK_THRESHOLD =
+ buildConf("spark.sql.execution.topKSortFallbackThreshold")
+ .internal()
+ .doc("In SQL queries with a SORT followed by a LIMIT like " +
+ "'SELECT x FROM t ORDER BY y LIMIT m', if m is under this threshold, do a top-K sort" +
+ " in memory, otherwise do a global sort which spills to disk if necessary.")
+ .intConf
+ .createWithDefault(Int.MaxValue)
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -1197,6 +1317,13 @@ object SQLConf {
object Replaced {
val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces"
}
+
+ val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled")
+ .internal()
+ .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " +
+ "Other column values can be ignored during parsing even if they are malformed.")
+ .booleanConf
+ .createWithDefault(true)
}
/**
@@ -1211,17 +1338,11 @@ object SQLConf {
class SQLConf extends Serializable with Logging {
import SQLConf._
- if (Utils.isTesting && SparkEnv.get != null) {
- // assert that we're only accessing it on the driver.
- assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
- "SQLConf should only be created and accessed on the driver.")
- }
-
/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
@transient protected[spark] val settings = java.util.Collections.synchronizedMap(
new java.util.HashMap[String, String]())
- @transient private val reader = new ConfigReader(settings)
+ @transient protected val reader = new ConfigReader(settings)
/** ************************ Spark SQL Params/Hints ******************* */
@@ -1258,6 +1379,9 @@ class SQLConf extends Serializable with Logging {
def streamingNoDataProgressEventInterval: Long =
getConf(STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL)
+ def streamingNoDataMicroBatchesEnabled: Boolean =
+ getConf(STREAMING_NO_DATA_MICRO_BATCHES_ENABLED)
+
def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED)
def streamingProgressRetention: Int = getConf(STREAMING_PROGRESS_RETENTION)
@@ -1304,6 +1428,8 @@ class SQLConf extends Serializable with Logging {
def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED)
+ def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED)
+
def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED)
def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH)
@@ -1317,6 +1443,9 @@ class SQLConf extends Serializable with Logging {
def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value =
HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE))
+ def compareDateTimestampInTimestamp : Boolean =
+ getConf(TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP)
+
def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT)
def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY)
@@ -1349,10 +1478,12 @@ class SQLConf extends Serializable with Logging {
def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR)
- def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader)
+ def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN)
def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION)
+ def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD)
+
/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
@@ -1518,7 +1649,9 @@ class SQLConf extends Serializable with Logging {
def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION)
- def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE)
+ def arrowEnabled: Boolean = getConf(ARROW_EXECUTION_ENABLED)
+
+ def arrowFallbackEnabled: Boolean = getConf(ARROW_FALLBACK_ENABLED)
def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
@@ -1537,13 +1670,21 @@ class SQLConf extends Serializable with Logging {
def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS)
+ def disabledV2StreamingMicroBatchReaders: String =
+ getConf(DISABLED_V2_STREAMING_MICROBATCH_READERS)
+
def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING)
def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)
+ def allowCreatingManagedTableUsingNonemptyLocation: Boolean =
+ getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION)
+
def partitionOverwriteMode: PartitionOverwriteMode.Value =
PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE))
+ def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
@@ -1650,6 +1791,17 @@ class SQLConf extends Serializable with Logging {
}.toSeq
}
+ /**
+ * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN.
+ */
+ def redactOptions(options: Map[String, String]): Map[String, String] = {
+ val regexes = Seq(
+ getConf(SQL_OPTIONS_REDACTION_PATTERN),
+ SECRET_REDACTION_PATTERN.readFrom(reader))
+
+ regexes.foldLeft(options.toSeq) { case (opts, r) => Utils.redact(Some(r), opts) }.toMap
+ }
+
/**
* Return whether a given key is set in this [[SQLConf]].
*/
@@ -1657,7 +1809,7 @@ class SQLConf extends Serializable with Logging {
settings.containsKey(key)
}
- private def setConfWithCheck(key: String, value: String): Unit = {
+ protected def setConfWithCheck(key: String, value: String): Unit = {
settings.put(key, value)
}
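
A hedged sketch of the new redactOptions helper (the option values below are made up; `conf` is assumed to be a SQLConf instance with the default redaction settings):

    val opts = Map(
      "url" -> "jdbc:postgresql://host/db?password=secret",  // key matches the default "(?i)url" regex
      "dbtable" -> "t")
    val redacted = conf.redactOptions(opts)
    assert(!redacted("url").contains("secret"))   // the url value is replaced by a dummy value
    assert(redacted("dbtable") == "t")            // non-matching keys are left untouched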
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index fe0ad39c29025..382ef28f49a7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -96,6 +96,14 @@ object StaticSQLConf {
.toSequence
.createOptional
+ val STREAMING_QUERY_LISTENERS = buildStaticConf("spark.sql.streaming.streamingQueryListeners")
+ .doc("List of class names implementing StreamingQueryListener that will be automatically " +
+ "added to newly created sessions. The classes should have either a no-arg constructor, " +
+ "or a constructor that expects a SparkConf argument.")
+ .stringConf
+ .toSequence
+ .createOptional
+
val UI_RETAINED_EXECUTIONS =
buildStaticConf("spark.sql.ui.retainedExecutions")
.doc("Number of executions to retain in the Spark UI.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index d6e0df12218ad..fd40741cfb5f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.types
import java.util.Locale
+import scala.util.control.NonFatal
+
import org.json4s._
import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
@@ -26,6 +28,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils
@@ -110,6 +113,14 @@ abstract class DataType extends AbstractDataType {
@InterfaceStability.Stable
object DataType {
+ def fromDDL(ddl: String): DataType = {
+ try {
+ CatalystSqlParser.parseDataType(ddl)
+ } catch {
+ case NonFatal(_) => CatalystSqlParser.parseTableSchema(ddl)
+ }
+ }
+
def fromJson(json: String): DataType = parseDataType(parse(json))
private val nonDecimalNameToType = {
@@ -295,25 +306,31 @@ object DataType {
}
/**
- * Returns true if the two data types share the same "shape", i.e. the types (including
- * nullability) are the same, but the field names don't need to be the same.
+ * Returns true if the two data types share the same "shape", i.e. the types
+ * are the same, but the field names don't need to be the same.
+ *
+ * @param ignoreNullability whether to ignore nullability when comparing the types
*/
- def equalsStructurally(from: DataType, to: DataType): Boolean = {
+ def equalsStructurally(
+ from: DataType,
+ to: DataType,
+ ignoreNullability: Boolean = false): Boolean = {
(from, to) match {
case (left: ArrayType, right: ArrayType) =>
equalsStructurally(left.elementType, right.elementType) &&
- left.containsNull == right.containsNull
+ (ignoreNullability || left.containsNull == right.containsNull)
case (left: MapType, right: MapType) =>
equalsStructurally(left.keyType, right.keyType) &&
equalsStructurally(left.valueType, right.valueType) &&
- left.valueContainsNull == right.valueContainsNull
+ (ignoreNullability || left.valueContainsNull == right.valueContainsNull)
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields)
.forall { case (l, r) =>
- equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable
+ equalsStructurally(l.dataType, r.dataType) &&
+ (ignoreNullability || l.nullable == r.nullable)
}
case (fromDataType, toDataType) => fromDataType == toDataType
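
A hedged sketch of the two additions above (values are illustrative): fromDDL first tries to parse a single data type and falls back to a table schema, and equalsStructurally can now ignore nullability while still ignoring field names.

    import org.apache.spark.sql.types._

    val arrayType = DataType.fromDDL("array<int>")        // parsed as a single data type
    val tableType = DataType.fromDDL("a INT, b STRING")   // falls back to parseTableSchema

    val s1 = new StructType().add("f", IntegerType, nullable = true)
    val s2 = new StructType().add("g", IntegerType, nullable = false)
    assert(!DataType.equalsStructurally(s1, s2))                           // nullability differs
    assert(DataType.equalsStructurally(s1, s2, ignoreNullability = true))  // field names still ignored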
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index ef3b67c0d48d0..dbf51c398fa47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -161,13 +161,17 @@ object DecimalType extends AbstractDataType {
* This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true.
*/
private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
- // Assumptions:
+ // Assumption:
assert(precision >= scale)
- assert(scale >= 0)
if (precision <= MAX_PRECISION) {
// Adjustment only needed when we exceed max precision
DecimalType(precision, scale)
+ } else if (scale < 0) {
+      // Decimal can have a negative scale (SPARK-24468). In this case, we cannot allow a
+      // precision loss, since that would drop digits from the integer part.
+      // Instead, we are likely to hit an overflow.
+ DecimalType(MAX_PRECISION, scale)
} else {
// Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
val intDigits = precision - scale
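
A hedged worked example of the new negative-scale branch (numbers are illustrative; adjustPrecisionScale itself is private[sql], so only the expected result type is shown):

    import org.apache.spark.sql.types.DecimalType

    // Multiplying decimal(20, -5) by decimal(20, -5) requests precision p1 + p2 + 1 = 41
    // and scale s1 + s2 = -10. Since 41 > MAX_PRECISION (38) and the scale is negative,
    // the new branch keeps the scale and only caps the precision:
    val adjusted = DecimalType(DecimalType.MAX_PRECISION, -10)  // expected adjusted type: decimal(38, -10)
    assert(adjusted.precision == 38 && adjusted.scale == -10)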
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index e3b0969283a84..362676b252126 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -104,6 +104,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
/** Returns all field names in an array. */
def fieldNames: Array[String] = fields.map(_.name)
+ /**
+ * Returns all field names in an array. This is an alias of `fieldNames`.
+ *
+ * @since 2.4.0
+ */
+ def names: Array[String] = fieldNames
+
private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
@@ -264,7 +271,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
*/
def apply(name: String): StructField = {
nameToField.getOrElse(name,
- throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
+ throw new IllegalArgumentException(
+ s"""Field "$name" does not exist.
+ |Available fields: ${fieldNames.mkString(", ")}""".stripMargin))
}
/**
@@ -277,7 +286,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
val nonExistFields = names -- fieldNamesSet
if (nonExistFields.nonEmpty) {
throw new IllegalArgumentException(
- s"Field ${nonExistFields.mkString(",")} does not exist.")
+ s"""Nonexistent field(s): ${nonExistFields.mkString(", ")}.
+ |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)
}
// Preserve the original order of fields.
StructType(fields.filter(f => names.contains(f.name)))
@@ -290,7 +300,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
*/
def fieldIndex(name: String): Int = {
nameToIndex.getOrElse(name,
- throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
+ throw new IllegalArgumentException(
+ s"""Field "$name" does not exist.
+ |Available fields: ${fieldNames.mkString(", ")}""".stripMargin))
}
private[sql] def getFieldIndex(name: String): Option[Int] = {
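
A hedged sketch of the behavior added above (field names are illustrative): `names` aliases `fieldNames`, and looking up a missing field now reports the available ones.

    import org.apache.spark.sql.types._

    val schema = new StructType().add("a", IntegerType).add("b", StringType)
    assert(schema.names.sameElements(Array("a", "b")))  // new alias for fieldNames

    try {
      schema("c")  // missing field
    } catch {
      case e: IllegalArgumentException =>
        // Field "c" does not exist.
        // Available fields: a, b
        assert(e.getMessage.contains("Available fields: a, b"))
    }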
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 5a944e763e099..6af16e2dba105 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -97,6 +97,16 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
override def catalogString: String = sqlType.simpleString
}
+private[spark] object UserDefinedType {
+ /**
+ * Get the sqlType of a (potential) [[UserDefinedType]].
+ */
+ def sqlType(dt: DataType): DataType = dt match {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case _ => dt
+ }
+}
+
/**
* The user defined type in Python.
*
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java
index b67c6f3e6e85e..76930f9368514 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java
@@ -17,7 +17,8 @@
package org.apache.spark.sql.catalyst.expressions;
-import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
+import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.types.UTF8String;
import org.junit.Assert;
import org.junit.Test;
@@ -53,7 +54,7 @@ public void testKnownStringAndIntInputs() {
for (int i = 0; i < inputs.length; i++) {
UTF8String s = UTF8String.fromString("val_" + inputs[i]);
- int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes());
+ int hash = HiveHasher.hashUnsafeBytesBlock(s.getMemoryBlock());
Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash));
}
}
@@ -89,13 +90,13 @@ public void randomizedStressTestBytes() {
int byteArrSize = rand.nextInt(100) * 8;
byte[] bytes = new byte[byteArrSize];
rand.nextBytes(bytes);
+ MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes);
Assert.assertEquals(
- HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize),
- HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+ HiveHasher.hashUnsafeBytesBlock(mb),
+ HiveHasher.hashUnsafeBytesBlock(mb));
- hashcodes.add(HiveHasher.hashUnsafeBytes(
- bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+ hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb));
}
// A very loose bound.
@@ -112,13 +113,13 @@ public void randomizedStressTestPaddedStrings() {
byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8);
byte[] paddedBytes = new byte[byteArrSize];
System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length);
+ MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes);
Assert.assertEquals(
- HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize),
- HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+ HiveHasher.hashUnsafeBytesBlock(mb),
+ HiveHasher.hashUnsafeBytesBlock(mb));
- hashcodes.add(HiveHasher.hashUnsafeBytes(
- paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+ hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb));
}
// A very loose bound.
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
index fb3dbe8ed1996..2da87113c6229 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
@@ -27,7 +27,6 @@
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
import org.apache.spark.unsafe.types.UTF8String;
@@ -55,36 +54,27 @@ private String getRandomString(int length) {
}
private UnsafeRow makeKeyRow(long k1, String k2) {
- UnsafeRow row = new UnsafeRow(2);
- BufferHolder holder = new BufferHolder(row, 32);
- UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
- holder.reset();
+ UnsafeRowWriter writer = new UnsafeRowWriter(2);
+ writer.reset();
writer.write(0, k1);
writer.write(1, UTF8String.fromString(k2));
- row.setTotalSize(holder.totalSize());
- return row;
+ return writer.getRow();
}
private UnsafeRow makeKeyRow(long k1, long k2) {
- UnsafeRow row = new UnsafeRow(2);
- BufferHolder holder = new BufferHolder(row, 0);
- UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
- holder.reset();
+ UnsafeRowWriter writer = new UnsafeRowWriter(2);
+ writer.reset();
writer.write(0, k1);
writer.write(1, k2);
- row.setTotalSize(holder.totalSize());
- return row;
+ return writer.getRow();
}
private UnsafeRow makeValueRow(long v1, long v2) {
- UnsafeRow row = new UnsafeRow(2);
- BufferHolder holder = new BufferHolder(row, 0);
- UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);
- holder.reset();
+ UnsafeRowWriter writer = new UnsafeRowWriter(2);
+ writer.reset();
writer.write(0, v1);
writer.write(1, v2);
- row.setTotalSize(holder.totalSize());
- return row;
+ return writer.getRow();
}
private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) {
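
The change above drops the manual BufferHolder plumbing from the test helpers; a hedged Scala sketch of the simplified writer API (values are illustrative):

    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
    import org.apache.spark.unsafe.types.UTF8String

    val writer = new UnsafeRowWriter(2)             // just the field count; no BufferHolder needed
    writer.reset()
    writer.write(0, 42L)
    writer.write(1, UTF8String.fromString("value"))
    val row = writer.getRow                         // total size is tracked internally
    assert(row.getLong(0) == 42L)
    assert(row.getUTF8String(1).toString == "value")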
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java
index 711887f02832a..cd8bce623c5df 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java
@@ -24,6 +24,8 @@
import java.util.Set;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
+import org.apache.spark.unsafe.memory.MemoryBlock;
import org.junit.Assert;
import org.junit.Test;
@@ -74,9 +76,6 @@ public void testKnownByteArrayInputs() {
Assert.assertEquals(0x739840CB819FA723L,
XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 1, PRIME));
- // These tests currently fail in a big endian environment because the test data and expected
- // answers are generated with little endian the assumptions. We could revisit this when Platform
- // becomes endian aware.
if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) {
Assert.assertEquals(0x9256E58AA397AEF1L,
hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4));
@@ -94,6 +93,23 @@ public void testKnownByteArrayInputs() {
hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE));
Assert.assertEquals(0xCAA65939306F1E21L,
XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME));
+ } else {
+ Assert.assertEquals(0x7F875412350ADDDCL,
+ hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4));
+ Assert.assertEquals(0x564D279F524D8516L,
+ XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4, PRIME));
+ Assert.assertEquals(0x7D9F07E27E0EB006L,
+ hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8));
+ Assert.assertEquals(0x893CEF564CB7858L,
+ XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8, PRIME));
+ Assert.assertEquals(0xC6198C4C9CC49E17L,
+ hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14));
+ Assert.assertEquals(0x4E21BEF7164D4BBL,
+ XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14, PRIME));
+ Assert.assertEquals(0xBCF5FAEDEE1F2B5AL,
+ hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE));
+ Assert.assertEquals(0x6F680C877A358FE5L,
+ XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME));
}
}
@@ -128,13 +144,13 @@ public void randomizedStressTestBytes() {
int byteArrSize = rand.nextInt(100) * 8;
byte[] bytes = new byte[byteArrSize];
rand.nextBytes(bytes);
+ MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes);
Assert.assertEquals(
- hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize),
- hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+ hasher.hashUnsafeWordsBlock(mb),
+ hasher.hashUnsafeWordsBlock(mb));
- hashcodes.add(hasher.hashUnsafeWords(
- bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+ hashcodes.add(hasher.hashUnsafeWordsBlock(mb));
}
// A very loose bound.
@@ -151,13 +167,13 @@ public void randomizedStressTestPaddedStrings() {
byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8);
byte[] paddedBytes = new byte[byteArrSize];
System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length);
+ MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes);
Assert.assertEquals(
- hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize),
- hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+ hasher.hashUnsafeWordsBlock(mb),
+ hasher.hashUnsafeWordsBlock(mb));
- hashcodes.add(hasher.hashUnsafeWords(
- paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+ hashcodes.add(hasher.hashUnsafeWordsBlock(mb));
}
// A very loose bound.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala
index 2d94b66a1e122..9a89e6290e695 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala
@@ -40,7 +40,7 @@ object HashBenchmark {
safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy()
).toArray
- val benchmark = new Benchmark("Hash For " + name, iters * numRows)
+ val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong)
benchmark.addCase("interpreted version") { _: Int =>
var sum = 0
for (_ <- 0L until iters) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala
index 2a753a0c84ed5..f6c8111f5bc57 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala
@@ -36,7 +36,8 @@ object HashByteArrayBenchmark {
bytes
}
- val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays)
+ val benchmark =
+ new Benchmark("Hash byte arrays with length " + length, iters * numArrays.toLong)
benchmark.addCase("Murmur3_x86_32") { _: Int =>
var sum = 0L
for (_ <- 0L until iters) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
index 769addf3b29e6..6c63769945312 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala
@@ -38,7 +38,7 @@ object UnsafeProjectionBenchmark {
val iters = 1024 * 16
val numRows = 1024 * 16
- val benchmark = new Benchmark("unsafe projection", iters * numRows)
+ val benchmark = new Benchmark("unsafe projection", iters * numRows.toLong)
val schema1 = new StructType().add("l", LongType, false)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index f3702ec92b425..f99af9b84d959 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -94,4 +94,49 @@ class CatalystTypeConvertersSuite extends SparkFunSuite {
assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray)
== doubleGenericArray)
}
+
+ test("converting a wrong value to the struct type") {
+ val structType = new StructType().add("f1", IntegerType)
+ val exception = intercept[IllegalArgumentException] {
+ CatalystTypeConverters.createToCatalystConverter(structType)("test")
+ }
+ assert(exception.getMessage.contains("The value (test) of the type "
+ + "(java.lang.String) cannot be converted to struct"))
+ }
+
+ test("converting a wrong value to the map type") {
+ val mapType = MapType(StringType, IntegerType, false)
+ val exception = intercept[IllegalArgumentException] {
+ CatalystTypeConverters.createToCatalystConverter(mapType)("test")
+ }
+ assert(exception.getMessage.contains("The value (test) of the type "
+ + "(java.lang.String) cannot be converted to a map type with key "
+ + "type (string) and value type (int)"))
+ }
+
+ test("converting a wrong value to the array type") {
+ val arrayType = ArrayType(IntegerType, true)
+ val exception = intercept[IllegalArgumentException] {
+ CatalystTypeConverters.createToCatalystConverter(arrayType)("test")
+ }
+ assert(exception.getMessage.contains("The value (test) of the type "
+ + "(java.lang.String) cannot be converted to an array of int"))
+ }
+
+ test("converting a wrong value to the decimal type") {
+ val decimalType = DecimalType(10, 0)
+ val exception = intercept[IllegalArgumentException] {
+ CatalystTypeConverters.createToCatalystConverter(decimalType)("test")
+ }
+ assert(exception.getMessage.contains("The value (test) of the type "
+ + "(java.lang.String) cannot be converted to decimal(10,0)"))
+ }
+
+ test("converting a wrong value to the string type") {
+ val exception = intercept[IllegalArgumentException] {
+ CatalystTypeConverters.createToCatalystConverter(StringType)(0.1)
+ }
+ assert(exception.getMessage.contains("The value (0.1) of the type "
+ + "(java.lang.Double) cannot be converted to the string type"))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 8c3db48a01f12..353b8344658f2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -365,4 +365,14 @@ class ScalaReflectionSuite extends SparkFunSuite {
StructField("_2", NullType, nullable = true))),
nullable = true))
}
+
+ test("SPARK-23835: add null check to non-nullable types in Tuples") {
+ def numberOfCheckedArguments(deserializer: Expression): Int = {
+ assert(deserializer.isInstanceOf[NewInstance])
+ deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull])
+ }
+ assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2)
+ assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
+ assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index c86dc18dfa680..bd87ca6017e99 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -272,6 +272,15 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
}
}
+ test("SPARK-24468: operations on decimals with negative scale") {
+ val a = AttributeReference("a", DecimalType(3, -10))()
+ val b = AttributeReference("b", DecimalType(1, -1))()
+ val c = AttributeReference("c", DecimalType(35, 1))()
+ checkType(Multiply(a, b), DecimalType(5, -11))
+ checkType(Multiply(a, c), DecimalType(38, -9))
+ checkType(Multiply(b, c), DecimalType(37, 0))
+ }
+
/** strength reduction for integer/decimal comparisons */
def ruleTest(initial: Expression, transformed: Expression): Unit = {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala
new file mode 100644
index 0000000000000..fe57c199b8744
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+
+/**
+ * Test suite for resolving Uuid expressions.
+ */
+class ResolvedUuidExpressionsSuite extends AnalysisTest {
+
+ private lazy val a = 'a.int
+ private lazy val r = LocalRelation(a)
+ private lazy val uuid1 = Uuid().as('_uuid1)
+ private lazy val uuid2 = Uuid().as('_uuid2)
+ private lazy val uuid3 = Uuid().as('_uuid3)
+ private lazy val uuid1Ref = uuid1.toAttribute
+
+ private val analyzer = getAnalyzer(caseSensitive = true)
+
+ private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = {
+ plan.flatMap {
+ case p =>
+ p.expressions.flatMap(_.collect {
+ case u: Uuid => u
+ })
+ }
+ }
+
+ test("analyzed plan sets random seed for Uuid expression") {
+ val plan = r.select(a, uuid1)
+ val resolvedPlan = analyzer.executeAndCheck(plan)
+ getUuidExpressions(resolvedPlan).foreach { u =>
+ assert(u.resolved)
+ assert(u.randomSeed.isDefined)
+ }
+ }
+
+ test("Uuid expressions should have different random seeds") {
+ val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
+ val resolvedPlan = analyzer.executeAndCheck(plan)
+ assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3)
+ }
+
+ test("Different analyzed plans should have different random seeds in Uuids") {
+ val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
+ val resolvedPlan1 = analyzer.executeAndCheck(plan)
+ val resolvedPlan2 = analyzer.executeAndCheck(plan)
+ val uuids1 = getUuidExpressions(resolvedPlan1)
+ val uuids2 = getUuidExpressions(resolvedPlan2)
+ assert(uuids1.distinct.length == 3)
+ assert(uuids2.distinct.length == 3)
+ assert(uuids1.intersect(uuids2).length == 0)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 52a7ebdafd7c7..0acd3b490447d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -429,6 +429,24 @@ class TypeCoercionSuite extends AnalysisTest {
Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))),
isSymmetric = false)
}
+
+ widenTest(
+ ArrayType(IntegerType, containsNull = true),
+ ArrayType(IntegerType, containsNull = false),
+ Some(ArrayType(IntegerType, containsNull = true)))
+
+ widenTest(
+ MapType(IntegerType, StringType, valueContainsNull = true),
+ MapType(IntegerType, StringType, valueContainsNull = false),
+ Some(MapType(IntegerType, StringType, valueContainsNull = true)))
+
+ widenTest(
+ new StructType()
+ .add("arr", ArrayType(IntegerType, containsNull = true), nullable = false),
+ new StructType()
+ .add("arr", ArrayType(IntegerType, containsNull = false), nullable = true),
+ Some(new StructType()
+ .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true)))
}
test("wider common type for decimal and array") {
@@ -506,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest {
test("cast NullType for expressions that implement ExpectsInputTypes") {
import TypeCoercionSuite._
- ruleTest(TypeCoercion.ImplicitTypeCasts,
+ ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
AnyTypeUnaryExpression(Literal.create(null, NullType)),
AnyTypeUnaryExpression(Literal.create(null, NullType)))
- ruleTest(TypeCoercion.ImplicitTypeCasts,
+ ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
NumericTypeUnaryExpression(Literal.create(null, NullType)),
NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
}
@@ -518,11 +536,11 @@ class TypeCoercionSuite extends AnalysisTest {
test("cast NullType for binary operators") {
import TypeCoercionSuite._
- ruleTest(TypeCoercion.ImplicitTypeCasts,
+ ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
- ruleTest(TypeCoercion.ImplicitTypeCasts,
+ ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
}
@@ -539,6 +557,9 @@ class TypeCoercionSuite extends AnalysisTest {
val floatLit = Literal.create(1.0f, FloatType)
val timestampLit = Literal.create("2017-04-12", TimestampType)
val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
+ val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis())))
+ val strArrayLit = Literal(Array("c"))
+ val intArrayLit = Literal(Array(1))
ruleTest(rule,
Coalesce(Seq(doubleLit, intLit, floatLit)),
@@ -572,6 +593,16 @@ class TypeCoercionSuite extends AnalysisTest {
Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)),
Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType),
Cast(doubleLit, StringType), Cast(stringLit, StringType))))
+
+ ruleTest(rule,
+ Coalesce(Seq(timestampLit, intLit, stringLit)),
+ Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType),
+ Cast(stringLit, StringType))))
+
+ ruleTest(rule,
+ Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)),
+ Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)),
+ Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType)))))
}
test("CreateArray casts") {
@@ -792,7 +823,7 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("type coercion for CaseKeyWhen") {
- ruleTest(TypeCoercion.ImplicitTypeCasts,
+ ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
)
@@ -1207,7 +1238,7 @@ class TypeCoercionSuite extends AnalysisTest {
*/
test("make sure rules do not fire early") {
// InConversion
- val inConversion = TypeCoercion.InConversion
+ val inConversion = TypeCoercion.InConversion(conf)
ruleTest(inConversion,
In(UnresolvedAttribute("a"), Seq(Literal(1))),
In(UnresolvedAttribute("a"), Seq(Literal(1)))
@@ -1244,25 +1275,47 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("SPARK-17117 null type coercion in divide") {
- val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts)
+ val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf))
val nullLit = Literal.create(null, NullType)
ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
}
test("binary comparison with string promotion") {
- ruleTest(PromoteStrings,
+ val rule = TypeCoercion.PromoteStrings(conf)
+ ruleTest(rule,
GreaterThan(Literal("123"), Literal(1)),
GreaterThan(Cast(Literal("123"), IntegerType), Literal(1)))
- ruleTest(PromoteStrings,
+ ruleTest(rule,
LessThan(Literal(true), Literal("123")),
LessThan(Literal(true), Cast(Literal("123"), BooleanType)))
- ruleTest(PromoteStrings,
+ ruleTest(rule,
EqualTo(Literal(Array(1, 2)), Literal("123")),
EqualTo(Literal(Array(1, 2)), Literal("123")))
- ruleTest(PromoteStrings,
+ ruleTest(rule,
GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))),
- GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), DoubleType)))
+ GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")),
+ DoubleType)))
+ Seq(true, false).foreach { convertToTS =>
+ withSQLConf(
+ "spark.sql.typeCoercion.compareDateTimestampInTimestamp" -> convertToTS.toString) {
+ val date0301 = Literal(java.sql.Date.valueOf("2017-03-01"))
+ val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00"))
+ val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01"))
+ if (convertToTS) {
+          // `Date` should be treated as a timestamp at 00:00:00. See SPARK-23549.
+ ruleTest(rule, EqualTo(date0301, timestamp0301000000),
+ EqualTo(Cast(date0301, TimestampType), timestamp0301000000))
+ ruleTest(rule, LessThan(date0301, timestamp0301000001),
+ LessThan(Cast(date0301, TimestampType), timestamp0301000001))
+ } else {
+ ruleTest(rule, LessThan(date0301, timestamp0301000000),
+ LessThan(Cast(date0301, StringType), Cast(timestamp0301000000, StringType)))
+ ruleTest(rule, LessThan(date0301, timestamp0301000001),
+ LessThan(Cast(date0301, StringType), Cast(timestamp0301000001, StringType)))
+ }
+ }
+ }
}
test("cast WindowFrame boundaries to the type they operate upon") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 60d1351fda264..cb487c8893541 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -621,6 +621,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
outputMode = Append,
expectedMsgs = Seq("monotonically_increasing_id"))
+ assertSupportedForContinuousProcessing(
+ "TypedFilter", TypedFilter(
+ null,
+ null,
+ null,
+ null,
+ new TestStreamingRelationV2(attribute)), OutputMode.Append())
/*
=======================================================================================
@@ -771,6 +778,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
}
}
+  /** Assert that the logical plan is supported in continuous processing mode. */
+ def assertSupportedForContinuousProcessing(
+ name: String,
+ plan: LogicalPlan,
+ outputMode: OutputMode): Unit = {
+ test(s"continuous processing - $name: supported") {
+ UnsupportedOperationChecker.checkForContinuous(plan, outputMode)
+ }
+ }
+
/**
* Assert that the logical plan is not supported inside a streaming plan.
*
@@ -840,4 +857,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
def this(attribute: Attribute) = this(Seq(attribute))
override def isStreaming: Boolean = true
}
+
+ case class TestStreamingRelationV2(output: Seq[Attribute]) extends LeafNode {
+ def this(attribute: Attribute) = this(Seq(attribute))
+ override def isStreaming: Boolean = true
+ override def nodeName: String = "StreamingRelationV2"
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala
index 1acbe34d9a075..2fcaeca34db3f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala
@@ -36,7 +36,7 @@ class ExternalCatalogEventSuite extends SparkFunSuite {
private def testWithCatalog(
name: String)(
f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) {
- val catalog = newCatalog
+ val catalog = new ExternalCatalogWithListener(newCatalog)
val recorder = mutable.Buffer.empty[ExternalCatalogEvent]
catalog.addListener(new ExternalCatalogEventListener {
override def onEvent(event: ExternalCatalogEvent): Unit = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
new file mode 100644
index 0000000000000..28e6940f3cca3
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.logical.Range
+
+class CanonicalizeSuite extends SparkFunSuite {
+
+ test("SPARK-24276: IN expression with different order are semantically equal") {
+ val range = Range(1, 1, 1, 1)
+ val idAttr = range.output.head
+
+ val in1 = In(idAttr, Seq(Literal(1), Literal(2)))
+ val in2 = In(idAttr, Seq(Literal(2), Literal(1)))
+ val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3)))
+
+ assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash())
+ assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash())
+
+ assert(range.where(in1).sameResult(range.where(in2)))
+ assert(!range.where(in1).sameResult(range.where(in3)))
+
+ val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))),
+ CreateArray(Seq(Literal(2), Literal(1)))))
+ val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))),
+ CreateArray(Seq(Literal(1), Literal(2)))))
+ val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))),
+ CreateArray(Seq(Literal(3), Literal(1)))))
+
+ assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash())
+ assert(arrays1.canonicalized.semanticHash() != arrays3.canonicalized.semanticHash())
+
+ assert(range.where(arrays1).sameResult(range.where(arrays2)))
+ assert(!range.where(arrays1).sameResult(range.where(arrays3)))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 676ba3956ddc8..5b71becee2de0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -405,12 +405,12 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-18016: define mutable states by using an array") {
val ctx1 = new CodegenContext
for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) {
- ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;")
+ ctx1.addMutableState(CodeGenerator.JAVA_INT, "i", v => s"$v = $i;")
}
assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
// When the number of primitive type mutable states is over the threshold, others are
// allocated into an array
- assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1)
+ assert(ctx1.arrayCompactedMutableStates.get(CodeGenerator.JAVA_INT).get.arrayNames.size == 1)
assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10)
val ctx2 = new CodegenContext
@@ -436,4 +436,67 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
ctx.addImmutableStateIfNotExists("String", mutableState2)
assert(ctx.inlinedMutableStates.length == 2)
}
+
+ test("SPARK-23628: calculateParamLength should compute properly the param length") {
+ assert(CodeGenerator.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101)
+ assert(CodeGenerator.calculateParamLength(
+ Seq.range(0, 100).map(x => Literal(x.toLong))) == 201)
+ }
+
+ test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") {
+
+ val ref = BoundReference(0, IntegerType, true)
+ val add1 = Add(ref, ref)
+ val add2 = Add(add1, add1)
+ val dummy = SubExprEliminationState(
+ JavaCode.variable("dummy", BooleanType),
+ JavaCode.variable("dummy", BooleanType))
+
+ // raw testing of basic functionality
+ {
+ val ctx = new CodegenContext
+ val e = ref.genCode(ctx)
+ // before
+ ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value)
+ assert(ctx.subExprEliminationExprs.contains(ref))
+ // call withSubExprEliminationExprs
+ ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) {
+ assert(ctx.subExprEliminationExprs.contains(add1))
+ assert(!ctx.subExprEliminationExprs.contains(ref))
+ Seq.empty
+ }
+ // after
+ assert(ctx.subExprEliminationExprs.nonEmpty)
+ assert(ctx.subExprEliminationExprs.contains(ref))
+ assert(!ctx.subExprEliminationExprs.contains(add1))
+ }
+
+ // emulate an actual codegen workload
+ {
+ val ctx = new CodegenContext
+ // before
+ ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE
+ assert(ctx.subExprEliminationExprs.contains(add1))
+ // call withSubExprEliminationExprs
+ ctx.withSubExprEliminationExprs(Map(ref -> dummy)) {
+ assert(ctx.subExprEliminationExprs.contains(ref))
+ assert(!ctx.subExprEliminationExprs.contains(add1))
+ Seq.empty
+ }
+ // after
+ assert(ctx.subExprEliminationExprs.nonEmpty)
+ assert(ctx.subExprEliminationExprs.contains(add1))
+ assert(!ctx.subExprEliminationExprs.contains(ref))
+ }
+ }
+
+ test("SPARK-23986: freshName can generate duplicated names") {
+ val ctx = new CodegenContext
+ val names1 = ctx.freshName("myName1") :: ctx.freshName("myName1") ::
+ ctx.freshName("myName11") :: Nil
+ assert(names1.distinct.length == 3)
+ val names2 = ctx.freshName("a") :: ctx.freshName("a") ::
+ ctx.freshName("a_1") :: ctx.freshName("a_0") :: Nil
+ assert(names2.distinct.length == 4)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala
new file mode 100644
index 0000000000000..531ca9a87370a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.PlanTestBase
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, LongType}
+
+class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase {
+
+ test("UnsafeProjection with codegen factory mode") {
+ val input = Seq(LongType, IntegerType)
+ .zipWithIndex.map(x => BoundReference(x._2, x._1, true))
+
+ val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
+ val obj = UnsafeProjection.createObject(input)
+ assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection"))
+ }
+
+ val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) {
+ val obj = UnsafeProjection.createObject(input)
+ assert(obj.isInstanceOf[InterpretedUnsafeProjection])
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 020687e4b3a27..85e692bdc4ef1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,33 +58,85 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapValues(m2), null)
}
+ test("MapEntries") {
+ def r(values: Any*): InternalRow = create_row(values: _*)
+
+ // Primitive-type keys/values
+ val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType))
+ val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType))
+ val mi2 = Literal.create(null, MapType(IntegerType, IntegerType))
+
+ checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2)))
+ checkEvaluation(MapEntries(mi1), Seq.empty)
+ checkEvaluation(MapEntries(mi2), null)
+
+ // Non-primitive-type keys/values
+ val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType))
+ val ms1 = Literal.create(Map[Int, Int](), MapType(StringType, StringType))
+ val ms2 = Literal.create(null, MapType(StringType, StringType))
+
+ checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null)))
+ checkEvaluation(MapEntries(ms1), Seq.empty)
+ checkEvaluation(MapEntries(ms2), null)
+ }
+
test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
- val a4 = Literal.create(Seq(null, null), ArrayType(NullType))
+ val d1 = new Decimal().set(10)
+ val d2 = new Decimal().set(100)
+ val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0)))
+ val a5 = Literal.create(Seq(null, null), ArrayType(NullType))
checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
checkEvaluation(new SortArray(a1), Seq[Integer]())
checkEvaluation(new SortArray(a2), Seq("a", "b"))
checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
+ checkEvaluation(new SortArray(a4), Seq(d1, d2))
checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
+ checkEvaluation(SortArray(a4, Literal(true)), Seq(d1, d2))
checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))
+ checkEvaluation(SortArray(a4, Literal(false)), Seq(d2, d1))
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
- checkEvaluation(new SortArray(a4), Seq(null, null))
+ checkEvaluation(new SortArray(a5), Seq(null, null))
val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS)
checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2)))
+
+ val typeAA = ArrayType(ArrayType(IntegerType))
+ val aa1 = Array[java.lang.Integer](1, 2)
+ val aa2 = Array[java.lang.Integer](3, null, 4)
+ val arrayArray = Literal.create(Seq(aa2, aa1), typeAA)
+
+ checkEvaluation(new SortArray(arrayArray), Seq(aa1, aa2))
+
+ val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil)))
+ val aas1 = Array(create_row(1))
+ val aas2 = Array(create_row(2))
+ val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS)
+
+ checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2))
+
+ checkEvaluation(ArraySort(a0), Seq(1, 2, 3))
+ checkEvaluation(ArraySort(a1), Seq[Integer]())
+ checkEvaluation(ArraySort(a2), Seq("a", "b"))
+ checkEvaluation(ArraySort(a3), Seq("a", "b", null))
+ checkEvaluation(ArraySort(a4), Seq(d1, d2))
+ checkEvaluation(ArraySort(a5), Seq(null, null))
+ checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2)))
+ checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2))
+ checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2))
}
test("Array contains") {
@@ -104,5 +158,612 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
+
+ // binary
+ val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
+ ArrayType(BinaryType))
+ val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
+ ArrayType(BinaryType))
+ val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
+ ArrayType(BinaryType))
+ val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)),
+ ArrayType(BinaryType))
+ val be = Literal.create(Array[Byte](1, 2), BinaryType)
+ val nullBinary = Literal.create(null, BinaryType)
+
+ checkEvaluation(ArrayContains(b0, be), true)
+ checkEvaluation(ArrayContains(b1, be), false)
+ checkEvaluation(ArrayContains(b0, nullBinary), null)
+ checkEvaluation(ArrayContains(b2, be), null)
+ checkEvaluation(ArrayContains(b3, be), true)
+
+ // complex data types
+ val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+ ArrayType(ArrayType(IntegerType)))
+ val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+ ArrayType(ArrayType(IntegerType)))
+ val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
+ checkEvaluation(ArrayContains(aa0, aae), true)
+ checkEvaluation(ArrayContains(aa1, aae), false)
+ }
+
+ test("ArraysOverlap") {
+ val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType))
+ val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType))
+ val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType))
+ val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+ val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType))
+ val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType))
+
+ val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType))
+
+ checkEvaluation(ArraysOverlap(a0, a1), true)
+ checkEvaluation(ArraysOverlap(a0, a2), null)
+ checkEvaluation(ArraysOverlap(a1, a2), true)
+ checkEvaluation(ArraysOverlap(a1, a3), false)
+ checkEvaluation(ArraysOverlap(a0, emptyIntArray), false)
+ checkEvaluation(ArraysOverlap(a2, emptyIntArray), false)
+ checkEvaluation(ArraysOverlap(emptyIntArray, a2), false)
+
+ checkEvaluation(ArraysOverlap(a4, a5), true)
+ checkEvaluation(ArraysOverlap(a4, a6), null)
+ checkEvaluation(ArraysOverlap(a5, a6), false)
+
+ // null handling
+ checkEvaluation(ArraysOverlap(emptyIntArray, a2), false)
+ checkEvaluation(ArraysOverlap(
+ emptyIntArray, Literal.create(Seq(null), ArrayType(IntegerType))), false)
+ checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null)
+ checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null)
+ checkEvaluation(ArraysOverlap(
+ Literal.create(Seq(null), ArrayType(IntegerType)),
+ Literal.create(Seq(null), ArrayType(IntegerType))), null)
+
+ // arrays of binaries
+ val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)),
+ ArrayType(BinaryType))
+ val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
+ ArrayType(BinaryType))
+ val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
+ ArrayType(BinaryType))
+
+ checkEvaluation(ArraysOverlap(b0, b1), true)
+ checkEvaluation(ArraysOverlap(b0, b2), false)
+
+ // arrays of complex data types
+ val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")),
+ ArrayType(ArrayType(StringType)))
+ val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")),
+ ArrayType(ArrayType(StringType)))
+ val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")),
+ ArrayType(ArrayType(StringType)))
+
+ checkEvaluation(ArraysOverlap(aa0, aa1), true)
+ checkEvaluation(ArraysOverlap(aa0, aa2), false)
+
+ // null handling with complex datatypes
+ val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], ArrayType(BinaryType))
+ val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType))
+ checkEvaluation(ArraysOverlap(emptyBinaryArray, b0), false)
+ checkEvaluation(ArraysOverlap(b0, emptyBinaryArray), false)
+ checkEvaluation(ArraysOverlap(emptyBinaryArray, arrayWithBinaryNull), false)
+ checkEvaluation(ArraysOverlap(arrayWithBinaryNull, emptyBinaryArray), false)
+ checkEvaluation(ArraysOverlap(arrayWithBinaryNull, b0), null)
+ checkEvaluation(ArraysOverlap(b0, arrayWithBinaryNull), null)
+ }
+
+ test("Slice") {
+ val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
+ val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType))
+ val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType))
+
+ checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2))
+ checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5))
+ checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6))
+ checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6))
+ checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)),
+ "Unexpected value for length")
+ checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)),
+ "Unexpected value for start")
+ checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int])
+ checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String])
+ checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null)
+ checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null)
+ checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)),
+ null)
+
+ checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b"))
+ checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null))
+ checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int])
+ checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String])
+ checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4))
+ }
+
+ test("ArrayJoin") {
+ def testArrays(
+ arrays: Seq[Expression],
+ nullReplacement: Option[Expression],
+ expected: Seq[String]): Unit = {
+ assert(arrays.length == expected.length)
+ arrays.zip(expected).foreach { case (arr, exp) =>
+ checkEvaluation(ArrayJoin(arr, Literal(","), nullReplacement), exp)
+ }
+ }
+
+ val arrays = Seq(Literal.create(Seq[String]("a", "b"), ArrayType(StringType)),
+ Literal.create(Seq[String]("a", null, "b"), ArrayType(StringType)),
+ Literal.create(Seq[String](null), ArrayType(StringType)),
+ Literal.create(Seq[String]("a", "b", null), ArrayType(StringType)),
+ Literal.create(Seq[String](null, "a", "b"), ArrayType(StringType)),
+ Literal.create(Seq[String]("a"), ArrayType(StringType)))
+
+ val withoutNullReplacement = Seq("a,b", "a,b", "", "a,b", "a,b", "a")
+ val withNullReplacement = Seq("a,b", "a,NULL,b", "NULL", "a,b,NULL", "NULL,a,b", "a")
+ testArrays(arrays, None, withoutNullReplacement)
+ testArrays(arrays, Some(Literal("NULL")), withNullReplacement)
+
+ checkEvaluation(ArrayJoin(
+ Literal.create(null, ArrayType(StringType)), Literal(","), None), null)
+ checkEvaluation(ArrayJoin(
+ Literal.create(Seq[String](null), ArrayType(StringType)),
+ Literal.create(null, StringType),
+ None), null)
+ checkEvaluation(ArrayJoin(
+ Literal.create(Seq[String](null), ArrayType(StringType)),
+ Literal(","),
+ Some(Literal.create(null, StringType))), null)
+ }
+
+ test("ArraysZip") {
+ val literals = Seq(
+ Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)),
+ Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType)),
+ Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType)),
+ Literal.create(Seq("a", null, "c"), ArrayType(StringType)),
+ Literal.create(Seq(null, false, true), ArrayType(BooleanType)),
+ Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType)),
+ Literal.create(Seq(), ArrayType(NullType)),
+ Literal.create(Seq(null), ArrayType(NullType)),
+ Literal.create(Seq(192.toByte), ArrayType(ByteType)),
+ Literal.create(
+ Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType))),
+ Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType))
+ )
+
+ checkEvaluation(ArraysZip(Seq(literals(0), literals(1))),
+ List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L)))
+
+ checkEvaluation(ArraysZip(Seq(literals(0), literals(2))),
+ List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null)))
+
+ checkEvaluation(ArraysZip(Seq(literals(0), literals(3))),
+ List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null)))
+
+ checkEvaluation(ArraysZip(Seq(literals(0), literals(4))),
+ List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null)))
+
+ checkEvaluation(ArraysZip(Seq(literals(0), literals(5))),
+ List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null)))
+
+ checkEvaluation(ArraysZip(Seq(literals(0), literals(6))),
+ List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null)))
+
+ checkEvaluation(ArraysZip(Seq(literals(0), literals(7))),
+ List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null)))
+
+ checkEvaluation(ArraysZip(Seq(literals(0), literals(1), literals(2), literals(3))),
+ List(
+ Row(9001, null, -1, "a"),
+ Row(9002, 1L, -3, null),
+ Row(9003, null, 900, "c"),
+ Row(null, 4L, null, null),
+ Row(null, 11L, null, null)))
+
+ checkEvaluation(ArraysZip(Seq(literals(4), literals(5), literals(6), literals(7), literals(8))),
+ List(
+ Row(null, 1.1, null, null, 192.toByte),
+ Row(false, null, null, null, null),
+ Row(true, 1.3, null, null, null),
+ Row(null, null, null, null, null)))
+
+ checkEvaluation(ArraysZip(Seq(literals(9), literals(0))),
+ List(
+ Row(List(1, 2, 3), 9001),
+ Row(null, 9002),
+ Row(List(4, 5), 9003),
+ Row(List(1, null, 3), null)))
+
+ checkEvaluation(ArraysZip(Seq(literals(7), literals(10))),
+ List(Row(null, Array[Byte](1.toByte, 5.toByte))))
+
+ val longLiteral =
+ Literal.create((0 to 1000).toSeq, ArrayType(IntegerType))
+
+ checkEvaluation(ArraysZip(Seq(literals(0), longLiteral)),
+ List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++
+ (3 to 1000).map { Row(null, _) }.toList)
+
+ val manyLiterals = (0 to 1000).map { _ =>
+ Literal.create(Seq(1), ArrayType(IntegerType))
+ }.toSeq
+
+ val numbers = List(
+ Row(Seq(9001) ++ (0 to 1000).map { _ => 1 }.toSeq: _*),
+ Row(Seq(9002) ++ (0 to 1000).map { _ => null }.toSeq: _*),
+ Row(Seq(9003) ++ (0 to 1000).map { _ => null }.toSeq: _*),
+ Row(Seq(null) ++ (0 to 1000).map { _ => null }.toSeq: _*))
+ checkEvaluation(ArraysZip(Seq(literals(0)) ++ manyLiterals),
+ List(numbers(0), numbers(1), numbers(2), numbers(3)))
+
+ checkEvaluation(ArraysZip(Seq(literals(0), Literal.create(null, ArrayType(IntegerType)))), null)
+ checkEvaluation(ArraysZip(Seq()), List())
+ }
+
+ test("Array Min") {
+ checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
+ checkEvaluation(
+ ArrayMin(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "")
+ checkEvaluation(ArrayMin(Literal.create(Seq(null), ArrayType(LongType))), null)
+ checkEvaluation(ArrayMin(Literal.create(null, ArrayType(StringType))), null)
+ checkEvaluation(
+ ArrayMin(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 0.1234)
+ }
+
+ test("Array max") {
+ checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10)
+ checkEvaluation(
+ ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc")
+ checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null)
+ checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null)
+ checkEvaluation(
+ ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
+ }
+
+ test("Reverse") {
+ // Primitive-type elements
+ val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType))
+ val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
+ val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType))
+ val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType))
+ val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType))
+ val ai5 = Literal.create(Seq(1), ArrayType(IntegerType))
+ val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType))
+ val ai7 = Literal.create(null, ArrayType(IntegerType))
+
+ checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2))
+ checkEvaluation(Reverse(ai1), Seq(3, 1, 2))
+ checkEvaluation(Reverse(ai2), Seq(3, null, 1, null))
+ checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2))
+ checkEvaluation(Reverse(ai4), Seq(null, null, null))
+ checkEvaluation(Reverse(ai5), Seq(1))
+ checkEvaluation(Reverse(ai6), Seq.empty)
+ checkEvaluation(Reverse(ai7), null)
+
+ // Non-primitive-type elements
+ val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType))
+ val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType))
+ val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType))
+ val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType))
+ val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType))
+ val as5 = Literal.create(Seq("a"), ArrayType(StringType))
+ val as6 = Literal.create(Seq.empty, ArrayType(StringType))
+ val as7 = Literal.create(null, ArrayType(StringType))
+ val aa = Literal.create(
+ Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")),
+ ArrayType(ArrayType(StringType)))
+
+ checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b"))
+ checkEvaluation(Reverse(as1), Seq("c", "a", "b"))
+ checkEvaluation(Reverse(as2), Seq("c", null, "a", null))
+ checkEvaluation(Reverse(as3), Seq(null, "d", null, "b"))
+ checkEvaluation(Reverse(as4), Seq(null, null, null))
+ checkEvaluation(Reverse(as5), Seq("a"))
+ checkEvaluation(Reverse(as6), Seq.empty)
+ checkEvaluation(Reverse(as7), null)
+ checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
+ }
+
+ test("Array Position") {
+ val a0 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+ val a2 = Literal.create(Seq(null), ArrayType(LongType))
+ val a3 = Literal.create(null, ArrayType(StringType))
+
+ checkEvaluation(ArrayPosition(a0, Literal(3)), 4L)
+ checkEvaluation(ArrayPosition(a0, Literal(1)), 1L)
+ checkEvaluation(ArrayPosition(a0, Literal(0)), 0L)
+ checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null)
+
+ checkEvaluation(ArrayPosition(a1, Literal("")), 2L)
+ checkEvaluation(ArrayPosition(a1, Literal("a")), 0L)
+ checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null)
+
+ checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L)
+ checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null)
+
+ checkEvaluation(ArrayPosition(a3, Literal("")), null)
+ checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
+
+ val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+ ArrayType(ArrayType(IntegerType)))
+ val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+ ArrayType(ArrayType(IntegerType)))
+ val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
+ checkEvaluation(ArrayPosition(aa0, aae), 1L)
+ checkEvaluation(ArrayPosition(aa1, aae), 0L)
+ }
+
+ test("elementAt") {
+ val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+ val a2 = Literal.create(Seq(null), ArrayType(LongType))
+ val a3 = Literal.create(null, ArrayType(StringType))
+
+ assert(intercept[Exception] {
+ checkEvaluation(ElementAt(a0, Literal(0)), null)
+ }.getMessage.contains("SQL array indices start at 1"))
+ intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) }
+ checkEvaluation(ElementAt(a0, Literal(4)), null)
+ checkEvaluation(ElementAt(a0, Literal(-4)), null)
+
+ checkEvaluation(ElementAt(a0, Literal(1)), 1)
+ checkEvaluation(ElementAt(a0, Literal(2)), 2)
+ checkEvaluation(ElementAt(a0, Literal(3)), 3)
+ checkEvaluation(ElementAt(a0, Literal(-3)), 1)
+ checkEvaluation(ElementAt(a0, Literal(-2)), 2)
+ checkEvaluation(ElementAt(a0, Literal(-1)), 3)
+
+ checkEvaluation(ElementAt(a1, Literal(1)), null)
+ checkEvaluation(ElementAt(a1, Literal(2)), "")
+ checkEvaluation(ElementAt(a1, Literal(-2)), null)
+ checkEvaluation(ElementAt(a1, Literal(-1)), "")
+
+ checkEvaluation(ElementAt(a2, Literal(1)), null)
+
+ checkEvaluation(ElementAt(a3, Literal(1)), null)
+
+
+ val m0 =
+ Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType))
+ val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+ val m2 = Literal.create(null, MapType(StringType, StringType))
+
+ assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure)
+
+ checkEvaluation(ElementAt(m0, Literal("d")), null)
+
+ checkEvaluation(ElementAt(m1, Literal("a")), null)
+
+ checkEvaluation(ElementAt(m0, Literal("a")), "1")
+ checkEvaluation(ElementAt(m0, Literal("b")), "2")
+ checkEvaluation(ElementAt(m0, Literal("c")), null)
+
+ checkEvaluation(ElementAt(m2, Literal("a")), null)
+
+ // test binary type as keys
+ val mb0 = Literal.create(
+ Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
+ MapType(BinaryType, StringType))
+ val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType))
+
+ checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null)
+
+ checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null)
+ checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
+ checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
+ }
+
+ test("Concat") {
+ // Primitive-type elements
+ val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+ val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
+ val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType))
+ val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType))
+ val ai4 = Literal.create(null, ArrayType(IntegerType))
+
+ checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3))
+ checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3))
+ checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3))
+ checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3))
+ checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5))
+ checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4, null, 5))
+ checkEvaluation(Concat(Seq(ai4)), null)
+ checkEvaluation(Concat(Seq(ai0, ai4)), null)
+ checkEvaluation(Concat(Seq(ai4, ai0)), null)
+
+ // Non-primitive-type elements
+ val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
+ val as1 = Literal.create(Seq.empty[String], ArrayType(StringType))
+ val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType))
+ val as3 = Literal.create(Seq(null, null), ArrayType(StringType))
+ val as4 = Literal.create(null, ArrayType(StringType))
+
+ val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType)))
+ val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType)))
+
+ checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c"))
+ checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c"))
+ checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c"))
+ checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c"))
+ checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e"))
+ checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null, "d", null, "e"))
+ checkEvaluation(Concat(Seq(as4)), null)
+ checkEvaluation(Concat(Seq(as0, as4)), null)
+ checkEvaluation(Concat(Seq(as4, as0)), null)
+
+ checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f")))
+ }
+
+ test("Flatten") {
+ // Primitive-type test cases
+ val intArrayType = ArrayType(ArrayType(IntegerType))
+
+ // Main test cases (primitive type)
+ val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType)
+ val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType)
+
+ checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6))
+ checkEvaluation(Flatten(aim2), Seq(1, 2, 3))
+
+ // Test cases with an empty array (primitive type)
+ val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType)
+ val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType)
+ val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType)
+ val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType)
+ val aie5 = Literal.create(Seq(Seq.empty), intArrayType)
+ val aie6 = Literal.create(Seq.empty, intArrayType)
+
+ checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4))
+ checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4))
+ checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4))
+ checkEvaluation(Flatten(aie4), Seq.empty)
+ checkEvaluation(Flatten(aie5), Seq.empty)
+ checkEvaluation(Flatten(aie6), Seq.empty)
+
+ // Test cases with null elements (primitive type)
+ val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType)
+ val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType)
+ val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType)
+
+ checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null))
+ checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null))
+ checkEvaluation(Flatten(ain3), Seq(null, null, null, null))
+
+ // Test cases with a null array (primitive type)
+ val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType)
+ val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType)
+ val aia3 = Literal.create(Seq(null), intArrayType)
+ val aia4 = Literal.create(null, intArrayType)
+
+ checkEvaluation(Flatten(aia1), null)
+ checkEvaluation(Flatten(aia2), null)
+ checkEvaluation(Flatten(aia3), null)
+ checkEvaluation(Flatten(aia4), null)
+
+ // Non-primitive-type test cases
+ val strArrayType = ArrayType(ArrayType(StringType))
+ val arrArrayType = ArrayType(ArrayType(ArrayType(StringType)))
+
+ // Main test cases (non-primitive type)
+ val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType)
+ val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType)
+ val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType)
+
+ checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f"))
+ checkEvaluation(Flatten(asm2), Seq("a", "b"))
+ checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e")))
+
+ // Test cases with an empty array (non-primitive type)
+ val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType)
+ val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType)
+ val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType)
+ val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType)
+ val ase5 = Literal.create(Seq(Seq.empty), strArrayType)
+ val ase6 = Literal.create(Seq.empty, strArrayType)
+
+ checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d"))
+ checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d"))
+ checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d"))
+ checkEvaluation(Flatten(ase4), Seq.empty)
+ checkEvaluation(Flatten(ase5), Seq.empty)
+ checkEvaluation(Flatten(ase6), Seq.empty)
+
+ // Test cases with null elements (non-primitive type)
+ val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType)
+ val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType)
+ val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType)
+
+ checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null))
+ checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null))
+ checkEvaluation(Flatten(asn3), Seq(null, null, null, null))
+
+ // Test cases with a null array (non-primitive type)
+ val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType)
+ val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType)
+ val asa3 = Literal.create(Seq(null), strArrayType)
+ val asa4 = Literal.create(null, strArrayType)
+
+ checkEvaluation(Flatten(asa1), null)
+ checkEvaluation(Flatten(asa2), null)
+ checkEvaluation(Flatten(asa3), null)
+ checkEvaluation(Flatten(asa4), null)
+ }
+
+ test("ArrayRepeat") {
+ val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType))
+ val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType))
+
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq())
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq())
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi"))
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi"))
+ checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true))
+ checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1))
+ checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2))
+ checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null))
+ checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null))
+ checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2)))
+ checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola")))
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null)
+ }
+
+ test("Array remove") {
+ val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType))
+ val a2 = Literal.create(Seq[String](null, "", null, ""), ArrayType(StringType))
+ val a3 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
+ val a4 = Literal.create(null, ArrayType(StringType))
+ val a5 = Literal.create(Seq(1, null, 8, 9, null), ArrayType(IntegerType))
+ val a6 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType))
+
+ checkEvaluation(ArrayRemove(a0, Literal(0)), Seq(1, 2, 3, 2, 2, 5))
+ checkEvaluation(ArrayRemove(a0, Literal(1)), Seq(2, 3, 2, 2, 5))
+ checkEvaluation(ArrayRemove(a0, Literal(2)), Seq(1, 3, 5))
+ checkEvaluation(ArrayRemove(a0, Literal(3)), Seq(1, 2, 2, 2, 5))
+ checkEvaluation(ArrayRemove(a0, Literal(5)), Seq(1, 2, 3, 2, 2))
+ checkEvaluation(ArrayRemove(a0, Literal(null, IntegerType)), null)
+
+ checkEvaluation(ArrayRemove(a1, Literal("")), Seq("b", "a", "a", "c", "b"))
+ checkEvaluation(ArrayRemove(a1, Literal("a")), Seq("b", "c", "b"))
+ checkEvaluation(ArrayRemove(a1, Literal("b")), Seq("a", "a", "c"))
+ checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b"))
+
+ checkEvaluation(ArrayRemove(a2, Literal("")), Seq(null, null))
+ checkEvaluation(ArrayRemove(a2, Literal(null, StringType)), null)
+
+ checkEvaluation(ArrayRemove(a3, Literal(1)), Seq.empty[Integer])
+
+ checkEvaluation(ArrayRemove(a4, Literal("a")), null)
+
+ checkEvaluation(ArrayRemove(a5, Literal(9)), Seq(1, null, 8, null))
+ checkEvaluation(ArrayRemove(a6, Literal(false)), Seq(true, true))
+
+ // complex data types
+ val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2),
+ Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType))
+ val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
+ ArrayType(BinaryType))
+ val b2 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)),
+ ArrayType(BinaryType))
+ val nullBinary = Literal.create(null, BinaryType)
+
+ val dataToRemove1 = Literal.create(Array[Byte](5, 6), BinaryType)
+ checkEvaluation(ArrayRemove(b0, dataToRemove1),
+ Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](1, 2)))
+ checkEvaluation(ArrayRemove(b0, nullBinary), null)
+ checkEvaluation(ArrayRemove(b1, dataToRemove1), Seq[Array[Byte]](Array[Byte](2, 1), null))
+ checkEvaluation(ArrayRemove(b2, dataToRemove1), Seq[Array[Byte]](null, Array[Byte](1, 2)))
+
+ val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+ ArrayType(ArrayType(IntegerType)))
+ val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+ ArrayType(ArrayType(IntegerType)))
+ val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1)), ArrayType(ArrayType(IntegerType)))
+ val dataToRemove2 = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
+ checkEvaluation(ArrayRemove(c0, dataToRemove2), Seq[Seq[Int]](Seq[Int](3, 4)))
+ checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
+ checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 84190f0bd5f7d..726193b411737 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -180,12 +180,56 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
null, null)
}
intercept[RuntimeException] {
- checkEvalutionWithUnsafeProjection(
+ checkEvaluationWithUnsafeProjection(
CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
null, null)
}
}
+ test("MapFromArrays") {
+ def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
+ // The catalyst map is order-sensitive, so we create a ListMap here to preserve the element order.
+ scala.collection.immutable.ListMap(keys.zip(values): _*)
+ }
+
+ val intSeq = Seq(5, 10, 15, 20, 25)
+ val longSeq = intSeq.map(_.toLong)
+ val strSeq = intSeq.map(_.toString)
+ val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25)
+ val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25)
+ val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_))
+
+ val intArray = Literal.create(intSeq, ArrayType(IntegerType, false))
+ val longArray = Literal.create(longSeq, ArrayType(LongType, false))
+ val strArray = Literal.create(strSeq, ArrayType(StringType, false))
+
+ val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true))
+ val intWithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true))
+ val longWithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true))
+
+ val nullArray = Literal.create(null, ArrayType(StringType, false))
+
+ checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq))
+ checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq))
+ checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq))
+
+ checkEvaluation(
+ MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq))
+ checkEvaluation(
+ MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq))
+ checkEvaluation(
+ MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq))
+ checkEvaluation(MapFromArrays(nullArray, nullArray), null)
+
+ intercept[RuntimeException] {
+ checkEvaluation(MapFromArrays(intWithNullArray, strArray), null)
+ }
+ intercept[RuntimeException] {
+ checkEvaluation(
+ MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
+ }
+ }
+
test("CreateStruct") {
val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 786266a2c13c0..63b24fb9eb13a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -211,6 +211,17 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(DayOfWeek, DateType)
}
+ test("WeekDay") {
+ checkEvaluation(WeekDay(Literal.create(null, DateType)), null)
+ checkEvaluation(WeekDay(Literal(d)), 2)
+ checkEvaluation(WeekDay(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 2)
+ checkEvaluation(WeekDay(Cast(Literal(ts), DateType, gmtId)), 4)
+ checkEvaluation(WeekDay(Cast(Literal("2011-05-06"), DateType, gmtId)), 4)
+ checkEvaluation(WeekDay(Literal(new Date(sdf.parse("2017-05-27 13:10:15").getTime))), 5)
+ checkEvaluation(WeekDay(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 4)
+ checkConsistencyBetweenInterpretedAndCodegen(WeekDay, DateType)
+ }
+
test("WeekOfYear") {
checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null)
checkEvaluation(WeekOfYear(Literal(d)), 15)
@@ -453,34 +464,47 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
MonthsBetween(
Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)),
Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)),
- timeZoneId),
- 3.94959677)
- checkEvaluation(
- MonthsBetween(
- Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)),
- Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)),
- timeZoneId),
- 0.0)
+ Literal.TrueLiteral,
+ timeZoneId = timeZoneId), 3.94959677)
checkEvaluation(
MonthsBetween(
- Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)),
- Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)),
- timeZoneId),
- -2.0)
- checkEvaluation(
- MonthsBetween(
- Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)),
- Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)),
- timeZoneId),
- 1.0)
+ Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)),
+ Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)),
+ Literal.FalseLiteral,
+ timeZoneId = timeZoneId), 3.9495967741935485)
+
+ Seq(Literal.FalseLiteral, Literal.TrueLiteral).foreach { roundOff =>
+ checkEvaluation(
+ MonthsBetween(
+ Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)),
+ Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)),
+ roundOff,
+ timeZoneId = timeZoneId), 0.0)
+ checkEvaluation(
+ MonthsBetween(
+ Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)),
+ Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)),
+ roundOff,
+ timeZoneId = timeZoneId), -2.0)
+ checkEvaluation(
+ MonthsBetween(
+ Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)),
+ Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)),
+ roundOff,
+ timeZoneId = timeZoneId), 1.0)
+ }
val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00"))
val tnull = Literal.create(null, TimestampType)
- checkEvaluation(MonthsBetween(t, tnull, timeZoneId), null)
- checkEvaluation(MonthsBetween(tnull, t, timeZoneId), null)
- checkEvaluation(MonthsBetween(tnull, tnull, timeZoneId), null)
+ checkEvaluation(MonthsBetween(t, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null)
+ checkEvaluation(MonthsBetween(tnull, t, Literal.TrueLiteral, timeZoneId = timeZoneId), null)
+ checkEvaluation(
+ MonthsBetween(tnull, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null)
+ checkEvaluation(
+ MonthsBetween(t, t, Literal.create(null, BooleanType), timeZoneId = timeZoneId), null)
checkConsistencyBetweenInterpretedAndCodegen(
- (time1: Expression, time2: Expression) => MonthsBetween(time1, time2, timeZoneId),
- TimestampType, TimestampType)
+ (time1: Expression, time2: Expression, roundOff: Expression) =>
+ MonthsBetween(time1, time2, roundOff, timeZoneId = timeZoneId),
+ TimestampType, TimestampType, BooleanType)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index b4c8eab19c5cc..14bfa212b5496 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.reflect.ClassTag
+
import org.scalacheck.Gen
import org.scalactic.TripleEqualsSupport.Spread
import org.scalatest.exceptions.TestFailedException
@@ -24,10 +26,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
+import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
@@ -37,32 +41,39 @@ import org.apache.spark.util.Utils
/**
* A few helper functions for expression evaluation testing. Mixin this trait to use them.
*/
-trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
+trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase {
self: SparkFunSuite =>
protected def create_row(values: Any*): InternalRow = {
InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
}
- protected def checkEvaluation(
- expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
+ private def prepareEvaluation(expression: Expression): Expression = {
val serializer = new JavaSerializer(new SparkConf()).newInstance
val resolver = ResolveTimeZone(new SQLConf)
- val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
+ resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
+ }
+
+ protected def checkEvaluation(
+ expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
+ // Make it a method so a fresh expression is obtained every time.
+ def expr = prepareEvaluation(expression)
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
- checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow)
+ checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow)
}
checkEvaluationWithOptimization(expr, catalystValue, inputRow)
}
/**
* Check the equality between result of expression and expected value, it will handle
- * Array[Byte], Spread[Double], and MapData.
+ * Array[Byte], Spread[Double], MapData and Row.
*/
- protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = {
+ protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = {
+ val dataType = UserDefinedType.sqlType(exprDataType)
+
(result, expected) match {
case (result: Array[Byte], expected: Array[Byte]) =>
java.util.Arrays.equals(result, expected)
@@ -88,12 +99,48 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
if (expected.isNaN) result.isNaN else expected == result
case (result: Float, expected: Float) =>
if (expected.isNaN) result.isNaN else expected == result
+ case (result: UnsafeRow, expected: GenericInternalRow) =>
+ val structType = exprDataType.asInstanceOf[StructType]
+ result.toSeq(structType) == expected.toSeq(structType)
+ case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
case _ =>
result == expected
}
}
- protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
+ protected def checkExceptionInExpression[T <: Throwable : ClassTag](
+ expression: => Expression,
+ expectedErrMsg: String): Unit = {
+ checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg)
+ }
+
+ protected def checkExceptionInExpression[T <: Throwable : ClassTag](
+ expression: => Expression,
+ inputRow: InternalRow,
+ expectedErrMsg: String): Unit = {
+
+ def checkException(eval: => Unit, testMode: String): Unit = {
+ withClue(s"($testMode)") {
+ val errMsg = intercept[T] {
+ eval
+ }.getMessage
+ if (!errMsg.contains(expectedErrMsg)) {
+ fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found")
+ }
+ }
+ }
+
+ // Make it a method so a fresh expression is obtained every time.
+ def expr = prepareEvaluation(expression)
+ checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode")
+ checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode")
+ if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
+ checkException(evaluateWithUnsafeProjection(expr, inputRow), "unsafe mode")
+ }
+ }
+
+ protected def evaluateWithoutCodegen(
+ expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
expression.foreach {
case n: Nondeterministic => n.initialize(0)
case _ =>
@@ -122,7 +169,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
- val actual = try evaluate(expression, inputRow) catch {
+ val actual = try evaluateWithoutCodegen(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
if (!checkResult(actual, expected, expression.dataType)) {
@@ -137,23 +184,56 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
expression: Expression,
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
+ val actual = evaluateWithGeneratedMutableProjection(expression, inputRow)
+ if (!checkResult(actual, expected, expression.dataType)) {
+ val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+ fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input")
+ }
+ }
+ protected def evaluateWithGeneratedMutableProjection(
+ expression: Expression,
+ inputRow: InternalRow = EmptyRow): Any = {
val plan = generateProject(
GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
expression)
plan.initialize(0)
- val actual = plan(inputRow).get(0, expression.dataType)
- if (!checkResult(actual, expected, expression.dataType)) {
- val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
- fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input")
- }
+ plan(inputRow).get(0, expression.dataType)
}
- protected def checkEvalutionWithUnsafeProjection(
+ protected def checkEvaluationWithUnsafeProjection(
expression: Expression,
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
+ val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN)
+ for (fallbackMode <- modes) {
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
+ val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow)
+ val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+
+ if (expected == null) {
+ if (!unsafeRow.isNullAt(0)) {
+ val expectedRow = InternalRow(expected, expected)
+ fail("Incorrect evaluation in unsafe mode: " +
+ s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
+ }
+ } else {
+ val lit = InternalRow(expected, expected)
+ val expectedRow =
+ UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
+ if (unsafeRow != expectedRow) {
+ fail("Incorrect evaluation in unsafe mode: " +
+ s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
+ }
+ }
+ }
+ }
+ }
+
+ protected def evaluateWithUnsafeProjection(
+ expression: Expression,
+ inputRow: InternalRow = EmptyRow): InternalRow = {
// SPARK-16489 Explicitly doing code generation twice so code gen will fail if
// some expression is reusing variable names across different instances.
// This behavior is tested in ExpressionEvalHelperSuite.
@@ -163,24 +243,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
Alias(expression, s"Optimized($expression)2")() :: Nil),
expression)
- val unsafeRow = plan(inputRow)
- val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
-
- if (expected == null) {
- if (!unsafeRow.isNullAt(0)) {
- val expectedRow = InternalRow(expected, expected)
- fail("Incorrect evaluation in unsafe mode: " +
- s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
- }
- } else {
- val lit = InternalRow(expected, expected)
- val expectedRow =
- UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit)
- if (unsafeRow != expectedRow) {
- fail("Incorrect evaluation in unsafe mode: " +
- s"$expression, actual: $unsafeRow, expected: $expectedRow$input")
- }
- }
+ plan.initialize(0)
+ plan(inputRow)
}
protected def checkEvaluationWithOptimization(
@@ -292,7 +356,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = {
val interpret = try {
- evaluate(expr, inputRow)
+ evaluateWithoutCodegen(expr, inputRow)
} catch {
case e: Exception => fail(s"Exception evaluating $expr", e)
}
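
For reference, a minimal sketch (assuming the spark-catalyst test classpath, where ExpressionEvalHelper and PlanTestBase live; the suite name and expression are illustrative, not part of the patch) of how a suite can use the split-out helpers instead of the removed evaluate:

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Add, ExpressionEvalHelper, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTestBase

class EvalHelperUsageSketchSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase {
  test("interpreted and codegen paths agree") {
    val expr = Add(Literal(1), Literal(2))
    // Interpreted evaluation only (what `evaluate` used to do).
    assert(evaluateWithoutCodegen(expr) === 3)
    // Evaluation through a generated mutable projection.
    assert(evaluateWithGeneratedMutableProjection(expr) === 3)
    // checkEvaluation still exercises every mode, including unsafe projection.
    checkEvaluation(expr, 3)
  }
}
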
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
index 64b65e2070ed6..7c7c4cccee253 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression {
override def eval(input: InternalRow): Any = 10
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.copy(code =
- s"""
+ code"""
|int some_variable = 11;
|int ${ev.value} = 10;
""".stripMargin)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index a0bbe02f92354..00e97637eee7e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -22,11 +22,13 @@ import java.util.Calendar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase {
val json =
"""
|{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],
@@ -390,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val jsonData = """{"a": 1}"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
- JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId),
+ JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true),
InternalRow(1)
)
}
@@ -399,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val jsonData = """{"a" 1}"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
- JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId),
+ JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true),
null
)
// Other modes should still return `null`.
checkEvaluation(
- JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId),
+ JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true),
null
)
}
@@ -414,62 +416,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val input = """[{"a": 1}, {"a": 2}]"""
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = InternalRow(1) :: InternalRow(2) :: Nil
- checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
+ checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
}
test("from_json - input=object, schema=array, output=array of single row") {
val input = """{"a": 1}"""
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = InternalRow(1) :: Nil
- checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
+ checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
}
test("from_json - input=empty array, schema=array, output=empty array") {
val input = "[ ]"
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = Nil
- checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
+ checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
}
test("from_json - input=empty object, schema=array, output=array of single row with null") {
val input = "{ }"
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = InternalRow(null) :: Nil
- checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
+ checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
}
test("from_json - input=array of single object, schema=struct, output=single row") {
val input = """[{"a": 1}]"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = InternalRow(1)
- checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
+ checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
}
test("from_json - input=array, schema=struct, output=null") {
val input = """[{"a": 1}, {"a": 2}]"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = null
- checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
+ checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
}
test("from_json - input=empty array, schema=struct, output=null") {
val input = """[]"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = null
- checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
+ checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
}
test("from_json - input=empty object, schema=struct, output=single row with null") {
val input = """{ }"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = InternalRow(null)
- checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
+ checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
}
test("from_json null input column") {
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
- JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId),
+ JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true),
null
)
}
@@ -477,7 +479,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-20549: from_json bad UTF-8") {
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
- JsonToStructs(schema, Map.empty, Literal(badJson), gmtId),
+ JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true),
null)
}
@@ -489,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
c.set(2016, 0, 1, 0, 0, 0)
c.set(Calendar.MILLISECOND, 123)
checkEvaluation(
- JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId),
+ JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true),
InternalRow(c.getTimeInMillis * 1000L)
)
// The result doesn't change because the json string includes timezone string ("Z" here),
// which means the string represents the timestamp string in the timezone regardless of
// the timeZoneId parameter.
checkEvaluation(
- JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")),
+ JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true),
InternalRow(c.getTimeInMillis * 1000L)
)
@@ -510,7 +512,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
schema,
Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"),
Literal(jsonData2),
- Option(tz.getID)),
+ Option(tz.getID),
+ true),
InternalRow(c.getTimeInMillis * 1000L)
)
checkEvaluation(
@@ -519,7 +522,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss",
DateTimeUtils.TIMEZONE_OPTION -> tz.getID),
Literal(jsonData2),
- gmtId),
+ gmtId,
+ true),
InternalRow(c.getTimeInMillis * 1000L)
)
}
@@ -528,7 +532,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-19543: from_json empty input column") {
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
- JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId),
+ JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true),
null
)
}
@@ -680,4 +684,26 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
)
}
}
+
+ test("from_json missing fields") {
+ for (forceJsonNullableSchema <- Seq(false, true)) {
+ val input =
+ """{
+ | "a": 1,
+ | "c": "foo"
+ |}
+ |""".stripMargin
+ val jsonSchema = new StructType()
+ .add("a", LongType, nullable = false)
+ .add("b", StringType, nullable = false)
+ .add("c", StringType, nullable = false)
+ val output = InternalRow(1L, null, UTF8String.fromString("foo"))
+ val expr = JsonToStructs(
+ jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema)
+ checkEvaluation(expr, output)
+ val schema = expr.dataType
+ val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema
+ assert(schemaToCompare == schema)
+ }
+ }
}
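
For context on the extra boolean argument threaded through the JsonToStructs calls above, a small sketch (plain StructType API only; the object name is illustrative) of what relaxing a schema to nullable means for the declared fields:

import org.apache.spark.sql.types.{LongType, StringType, StructType}

object NullableSchemaSketch {
  def main(args: Array[String]): Unit = {
    val schema = new StructType()
      .add("a", LongType, nullable = false)
      .add("b", StringType, nullable = false)
    // Equivalent, for a flat struct, to the asNullable call in the
    // "from_json missing fields" test: every field becomes nullable.
    val relaxed = StructType(schema.fields.map(_.copy(nullable = true)))
    assert(relaxed.fields.forall(_.nullable))
    assert(schema.fields.forall(!_.nullable))
  }
}
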
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala
new file mode 100644
index 0000000000000..4d69dc32ace82
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.types.{IntegerType, StringType}
+
+class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ test("mask") {
+ checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "U", "l", "#"), "llll-UUUU-####-####")
+ checkEvaluation(
+ new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l"), Literal("#")),
+ "llll-UUUU-####-####")
+ checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l")),
+ "llll-UUUU-nnnn-nnnn")
+ checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U")), "xxxx-UUUU-nnnn-nnnn")
+ checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-nnnn")
+ checkEvaluation(new Mask(Literal(null, StringType)), null)
+ checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), null, "l", "#"), "llll-XXXX-####-####")
+ checkEvaluation(new Mask(
+ Literal("abcd-EFGH-8765-4321"),
+ Literal(null, StringType),
+ Literal(null, StringType),
+ Literal(null, StringType)), "xxxx-XXXX-nnnn-nnnn")
+ checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("Upper")),
+ "xxxx-UUUU-nnnn-nnnn")
+ checkEvaluation(new Mask(Literal("")), "")
+ checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("")), "xxxx-XXXX-nnnn-nnnn")
+ checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn")
+ // scalastyle:off nonascii
+ checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200")
+ checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")),
+ "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋")
+ // scalastyle:on nonascii
+ intercept[AnalysisException] {
+ checkEvaluation(new Mask(Literal(""), Literal(1)), "")
+ }
+ }
+
+ test("mask_first_n") {
+ checkEvaluation(MaskFirstN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"),
+ "lU#l-UFGH-8765")
+ checkEvaluation(new MaskFirstN(
+ Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")),
+ "llll-UFGH-8765-4321")
+ checkEvaluation(
+ new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")),
+ "llll-UFGH-8765-4321")
+ checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")),
+ "xxxx-UFGH-8765-4321")
+ checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)),
+ "xxxx-XFGH-8765-4321")
+ intercept[AnalysisException] {
+ checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "")
+ }
+ checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321")), "xxxx-EFGH-8765-4321")
+ checkEvaluation(new MaskFirstN(Literal(null, StringType)), null)
+ checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null),
+ "llll-EFGH-8765-4321")
+ checkEvaluation(new MaskFirstN(
+ Literal("abcd-EFGH-8765-4321"),
+ Literal(null, IntegerType),
+ Literal(null, StringType),
+ Literal(null, StringType),
+ Literal(null, StringType)), "xxxx-EFGH-8765-4321")
+ checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")),
+ "xxxx-UFGH-8765-4321")
+ checkEvaluation(new MaskFirstN(Literal("")), "")
+ checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")),
+ "xxxx-EFGH-8765-4321")
+ checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""),
+ "xxxx-XXXX-nnnn-nnnn")
+ checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""),
+ "abcd-EFGH-8765-4321")
+ // scalastyle:off nonascii
+ checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U")
+ checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)),
+ "あ, 𠀋, Xxxxo World")
+ // scalastyle:on nonascii
+ }
+
+ test("mask_last_n") {
+ checkEvaluation(MaskLastN(Literal("abcd-EFGH-aB3d"), 6, "U", "l", "#"),
+ "abcd-EFGU-lU#l")
+ checkEvaluation(new MaskLastN(
+ Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l"), Literal("#")),
+ "abcd-EFGU-####")
+ checkEvaluation(
+ new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l")),
+ "abcd-EFGU-nnnn")
+ checkEvaluation(
+ new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U")),
+ "abcd-EFGU-nnnn")
+ checkEvaluation(
+ new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6)),
+ "abcd-EFGX-nnnn")
+ intercept[AnalysisException] {
+ checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765"), Literal("U")), "")
+ }
+ checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321")), "abcd-EFGH-8765-nnnn")
+ checkEvaluation(new MaskLastN(Literal(null, StringType)), null)
+ checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null),
+ "abcd-EFGH-8765-nnnn")
+ checkEvaluation(new MaskLastN(
+ Literal("abcd-EFGH-8765-4321"),
+ Literal(null, IntegerType),
+ Literal(null, StringType),
+ Literal(null, StringType),
+ Literal(null, StringType)), "abcd-EFGH-8765-nnnn")
+ checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(12), Literal("Upper")),
+ "abcd-EFUU-nnnn-nnnn")
+ checkEvaluation(new MaskLastN(Literal("")), "")
+ checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(16), Literal("")),
+ "abcx-XXXX-nnnn-nnnn")
+ checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""),
+ "xxxx-XXXX-nnnn-nnnn")
+ checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""),
+ "abcd-EFGH-8765-4321")
+ // scalastyle:off nonascii
+ checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200")
+ checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World あ 𠀋"), Literal(10)),
+ "あ, 𠀋, Hello Xxxxx あ 𠀋")
+ // scalastyle:on nonascii
+ }
+
+ test("mask_show_first_n") {
+ checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-aB3d"), 6, "U", "l", "#"),
+ "abcd-EUUU-####-lU#l")
+ checkEvaluation(new MaskShowFirstN(
+ Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")),
+ "abcd-EUUU-####-####")
+ checkEvaluation(
+ new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")),
+ "abcd-EUUU-nnnn-nnnn")
+ checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")),
+ "abcd-EUUU-nnnn-nnnn")
+ checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)),
+ "abcd-EXXX-nnnn-nnnn")
+ intercept[AnalysisException] {
+ checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "")
+ }
+ checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321")), "abcd-XXXX-nnnn-nnnn")
+ checkEvaluation(new MaskShowFirstN(Literal(null, StringType)), null)
+ checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null),
+ "abcd-UUUU-nnnn-nnnn")
+ checkEvaluation(new MaskShowFirstN(
+ Literal("abcd-EFGH-8765-4321"),
+ Literal(null, IntegerType),
+ Literal(null, StringType),
+ Literal(null, StringType),
+ Literal(null, StringType)), "abcd-XXXX-nnnn-nnnn")
+ checkEvaluation(
+ new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")),
+ "abcd-EUUU-nnnn-nnnn")
+ checkEvaluation(new MaskShowFirstN(Literal("")), "")
+ checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")),
+ "abcd-XXXX-nnnn-nnnn")
+ checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""),
+ "abcd-EFGH-8765-4321")
+ checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""),
+ "xxxx-XXXX-nnnn-nnnn")
+ // scalastyle:off nonascii
+ checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200")
+ checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)),
+ "あ, 𠀋, Hellx Xxxxx")
+ // scalastyle:on nonascii
+ }
+
+ test("mask_show_last_n") {
+ checkEvaluation(MaskShowLastN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"),
+ "lU#l-UUUH-8765")
+ checkEvaluation(new MaskShowLastN(
+ Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")),
+ "llll-UUUU-###5-4321")
+ checkEvaluation(
+ new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")),
+ "llll-UUUU-nnn5-4321")
+ checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")),
+ "xxxx-UUUU-nnn5-4321")
+ checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6)),
+ "xxxx-XXXX-nnn5-4321")
+ intercept[AnalysisException] {
+ checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "")
+ }
+ checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-4321")
+ checkEvaluation(new MaskShowLastN(Literal(null, StringType)), null)
+ checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null),
+ "llll-UUUU-nnnn-4321")
+ checkEvaluation(new MaskShowLastN(
+ Literal("abcd-EFGH-8765-4321"),
+ Literal(null, IntegerType),
+ Literal(null, StringType),
+ Literal(null, StringType),
+ Literal(null, StringType)), "xxxx-XXXX-nnnn-4321")
+ checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")),
+ "xxxx-UUUU-nnn5-4321")
+ checkEvaluation(new MaskShowLastN(Literal("")), "")
+ checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")),
+ "xxxx-XXXX-nnnn-4321")
+ checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""),
+ "abcd-EFGH-8765-4321")
+ checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""),
+ "xxxx-XXXX-nnnn-nnnn")
+ // scalastyle:off nonascii
+ checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U")
+ checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)),
+ "あ, 𠀋, Xello World")
+ // scalastyle:on nonascii
+ }
+
+ test("mask_hash") {
+ checkEvaluation(MaskHash(Literal("abcd-EFGH-8765-4321")), "60c713f5ec6912229d2060df1c322776")
+ checkEvaluation(MaskHash(Literal("")), "d41d8cd98f00b204e9800998ecf8427e")
+ checkEvaluation(MaskHash(Literal(null, StringType)), null)
+ // scalastyle:off nonascii
+ checkEvaluation(MaskHash(Literal("\u2200x9U")), "f1243ef123d516b1f32a3a75309e5711")
+ // scalastyle:on nonascii
+ }
+}
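
As a plain-Scala illustration (independent of the Mask expressions themselves) of the default replacement convention the expectations above encode — upper-case letters become 'X', lower-case letters 'x', digits 'n', and everything else is kept:

object DefaultMaskSketch {
  def defaultMask(s: String): String = s.map {
    case c if c.isUpper => 'X'
    case c if c.isLower => 'x'
    case c if c.isDigit => 'n'
    case other => other
  }

  def main(args: Array[String]): Unit = {
    // Matches the expected value used throughout the suite above.
    assert(defaultMask("abcd-EFGH-8765-4321") == "xxxx-XXXX-nnnn-nnnn")
  }
}
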
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 39e0060d41dd4..3a094079380fd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -124,7 +124,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
private def checkNaNWithoutCodegen(
expression: Expression,
inputRow: InternalRow = EmptyRow): Unit = {
- val actual = try evaluate(expression, inputRow) catch {
+ val actual = try evaluateWithoutCodegen(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
if (!actual.asInstanceOf[Double].isNaN) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
index facc863081303..b6c269348b002 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
+import java.io.PrintStream
+
+import scala.util.Random
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
@@ -40,7 +44,49 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("uuid") {
- checkEvaluation(Length(Uuid()), 36)
- assert(evaluate(Uuid()) !== evaluate(Uuid()))
+ checkEvaluation(Length(Uuid(Some(0))), 36)
+ val r = new Random()
+ val seed1 = Some(r.nextLong())
+ assert(evaluateWithoutCodegen(Uuid(seed1)) === evaluateWithoutCodegen(Uuid(seed1)))
+ assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) ===
+ evaluateWithGeneratedMutableProjection(Uuid(seed1)))
+ assert(evaluateWithUnsafeProjection(Uuid(seed1)) ===
+ evaluateWithUnsafeProjection(Uuid(seed1)))
+
+ val seed2 = Some(r.nextLong())
+ assert(evaluateWithoutCodegen(Uuid(seed1)) !== evaluateWithoutCodegen(Uuid(seed2)))
+ assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) !==
+ evaluateWithGeneratedMutableProjection(Uuid(seed2)))
+ assert(evaluateWithUnsafeProjection(Uuid(seed1)) !==
+ evaluateWithUnsafeProjection(Uuid(seed2)))
+
+ val uuid = Uuid(seed1)
+ assert(uuid.fastEquals(uuid))
+ assert(!uuid.fastEquals(Uuid(seed1)))
+ assert(!uuid.fastEquals(uuid.freshCopy()))
+ assert(!uuid.fastEquals(Uuid(seed2)))
+ }
+
+ test("PrintToStderr") {
+ val inputExpr = Literal(1)
+ val systemErr = System.err
+
+ val (outputEval, outputCodegen) = try {
+ val errorStream = new java.io.ByteArrayOutputStream()
+ System.setErr(new PrintStream(errorStream))
+ // check without codegen
+ checkEvaluationWithoutCodegen(PrintToStderr(inputExpr), 1)
+ val outputEval = errorStream.toString
+ errorStream.reset()
+ // check with codegen
+ checkEvaluationWithGeneratedMutableProjection(PrintToStderr(inputExpr), 1)
+ val outputCodegen = errorStream.toString
+ (outputEval, outputCodegen)
+ } finally {
+ System.setErr(systemErr)
+ }
+
+ assert(outputCodegen.contains(s"Result of $inputExpr is 1"))
+ assert(outputEval.contains(s"Result of $inputExpr is 1"))
}
}
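
The PrintToStderr test swaps System.err and restores it in a finally block; a hypothetical helper (not part of this patch) that factors out that capture-and-restore pattern could look like:

import java.io.{ByteArrayOutputStream, PrintStream}

object StderrCapture {
  // Run `body` with System.err redirected to an in-memory buffer and return
  // the body's result together with everything written to stderr meanwhile.
  def withStderrCaptured[T](body: => T): (T, String) = {
    val saved = System.err
    val buffer = new ByteArrayOutputStream()
    System.setErr(new PrintStream(buffer))
    try {
      val result = body
      (result, buffer.toString)
    } finally {
      System.setErr(saved)
    }
  }
}
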
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index cc6c15cb2c909..424c3a4696077 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -51,7 +51,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("AssertNotNUll") {
val ex = intercept[RuntimeException] {
- evaluate(AssertNotNull(Literal(null), Seq.empty[String]))
+ evaluateWithoutCodegen(AssertNotNull(Literal(null), Seq.empty[String]))
}.getMessage
assert(ex.contains("Null value appeared in non-nullable field"))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 3edcc02f15264..20d568c44258f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -17,13 +17,52 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.objects.Invoke
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
-import org.apache.spark.sql.types.{IntegerType, ObjectType}
+import java.sql.{Date, Timestamp}
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.Random
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.encoders._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.objects._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+
+class InvokeTargetClass extends Serializable {
+ def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
+ def filterPrimitiveInt(e: Int): Boolean = e > 0
+ def binOp(e1: Int, e2: Double): Double = e1 + e2
+}
+
+class InvokeTargetSubClass extends InvokeTargetClass {
+ override def binOp(e1: Int, e2: Double): Double = e1 - e2
+}
+
+// Tests for NewInstance
+class Outer extends Serializable {
+ class Inner(val value: Int) {
+ override def hashCode(): Int = super.hashCode()
+ override def equals(other: Any): Boolean = {
+ if (other.isInstanceOf[Inner]) {
+ value == other.asInstanceOf[Inner].value
+ } else {
+ false
+ }
+ }
+ }
+}
class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -41,7 +80,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
val structExpected = new GenericArrayData(
Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
- checkEvalutionWithUnsafeProjection(
+ checkEvaluationWithUnsafeProjection(
structEncoder.serializer.head, structExpected, structInputRow)
// test UnsafeArray-backed data
@@ -49,7 +88,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
val arrayExpected = new GenericArrayData(
Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
- checkEvalutionWithUnsafeProjection(
+ checkEvaluationWithUnsafeProjection(
arrayEncoder.serializer.head, arrayExpected, arrayInputRow)
// test UnsafeMap-backed data
@@ -63,7 +102,507 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
new ArrayBasedMapData(
new GenericArrayData(Array(3, 4)),
new GenericArrayData(Array(300, 400)))))
- checkEvalutionWithUnsafeProjection(
+ checkEvaluationWithUnsafeProjection(
mapEncoder.serializer.head, mapExpected, mapInputRow)
}
+
+ test("SPARK-23582: StaticInvoke should support interpreted execution") {
+ Seq((classOf[java.lang.Boolean], "true", true),
+ (classOf[java.lang.Byte], "1", 1.toByte),
+ (classOf[java.lang.Short], "257", 257.toShort),
+ (classOf[java.lang.Integer], "12345", 12345),
+ (classOf[java.lang.Long], "12345678", 12345678.toLong),
+ (classOf[java.lang.Float], "12.34", 12.34.toFloat),
+ (classOf[java.lang.Double], "1.2345678", 1.2345678)
+ ).foreach { case (cls, arg, expected) =>
+ checkObjectExprEvaluation(StaticInvoke(cls, ObjectType(cls), "valueOf",
+ Seq(BoundReference(0, ObjectType(classOf[java.lang.String]), true))),
+ expected, InternalRow.fromSeq(Seq(arg)))
+ }
+
+ // Return null when null argument is passed with propagateNull = true
+ val stringCls = classOf[java.lang.String]
+ checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf",
+ Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = true),
+ null, InternalRow.fromSeq(Seq(null)))
+ checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf",
+ Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = false),
+ "null", InternalRow.fromSeq(Seq(null)))
+
+ // test no argument
+ val clCls = classOf[java.lang.ClassLoader]
+ checkObjectExprEvaluation(StaticInvoke(clCls, ObjectType(clCls), "getSystemClassLoader", Nil),
+ ClassLoader.getSystemClassLoader, InternalRow.empty)
+ // test more than one argument
+ val intCls = classOf[java.lang.Integer]
+ checkObjectExprEvaluation(StaticInvoke(intCls, ObjectType(intCls), "compare",
+ Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, false))),
+ 0, InternalRow.fromSeq(Seq(7, 7)))
+
+ Seq((DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", ObjectType(classOf[Timestamp]),
+ new Timestamp(77777), DateTimeUtils.fromJavaTimestamp(new Timestamp(77777))),
+ (DateTimeUtils.getClass, DateType, "fromJavaDate", ObjectType(classOf[Date]),
+ new Date(88888888), DateTimeUtils.fromJavaDate(new Date(88888888))),
+ (classOf[UTF8String], StringType, "fromString", ObjectType(classOf[String]),
+ "abc", UTF8String.fromString("abc")),
+ (Decimal.getClass, DecimalType(38, 0), "fromDecimal", ObjectType(classOf[Any]),
+ BigInt(88888888), Decimal.fromDecimal(BigInt(88888888))),
+ (Decimal.getClass, DecimalType.SYSTEM_DEFAULT,
+ "apply", ObjectType(classOf[java.math.BigInteger]),
+ new java.math.BigInteger("88888888"), Decimal.apply(new java.math.BigInteger("88888888"))),
+ (classOf[ArrayData], ArrayType(IntegerType), "toArrayData", ObjectType(classOf[Any]),
+ Array[Int](1, 2, 3), ArrayData.toArrayData(Array[Int](1, 2, 3))),
+ (classOf[UnsafeArrayData], ArrayType(IntegerType, false),
+ "fromPrimitiveArray", ObjectType(classOf[Array[Int]]),
+ Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))),
+ (DateTimeUtils.getClass, ObjectType(classOf[Date]),
+ "toJavaDate", ObjectType(classOf[DateTimeUtils.SQLDate]), 77777,
+ DateTimeUtils.toJavaDate(77777)),
+ (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]),
+ "toJavaTimestamp", ObjectType(classOf[DateTimeUtils.SQLTimestamp]),
+ 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888))
+ ).foreach { case (cls, dataType, methodName, argType, arg, expected) =>
+ checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName,
+ Seq(BoundReference(0, argType, true))), expected, InternalRow.fromSeq(Seq(arg)))
+ }
+ }
+
+ test("SPARK-23583: Invoke should support interpreted execution") {
+ val targetObject = new InvokeTargetClass
+ val funcClass = classOf[InvokeTargetClass]
+ val funcObj = Literal.create(targetObject, ObjectType(funcClass))
+ val targetSubObject = new InvokeTargetSubClass
+ val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass]))
+ val funcNullObj = Literal.create(null, ObjectType(funcClass))
+
+ val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true))
+ val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false))
+ val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false))
+
+ checkObjectExprEvaluation(
+ Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+ java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+
+ checkObjectExprEvaluation(
+ Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt),
+ false, InternalRow.fromSeq(Seq(-1)))
+
+ checkObjectExprEvaluation(
+ Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+ null, InternalRow.fromSeq(Seq(null)))
+
+ checkObjectExprEvaluation(
+ Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt),
+ null, InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+
+ checkObjectExprEvaluation(
+ Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, InternalRow.apply(1, 0.25))
+
+ checkObjectExprEvaluation(
+ Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25))
+ }
+
+ test("SPARK-23593: InitializeJavaBean should support interpreted execution") {
+ val list = new java.util.LinkedList[Int]()
+ list.add(1)
+
+ val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]),
+ Map("add" -> Literal(1)))
+ checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq()))
+
+ val initializeWithNonexistingMethod = InitializeJavaBean(
+ Literal.fromObject(new java.util.LinkedList[Int]),
+ Map("nonexisting" -> Literal(1)))
+ checkExceptionInExpression[Exception](initializeWithNonexistingMethod,
+ """A method named "nonexisting" is not declared in any enclosing class """ +
+ "nor any supertype")
+
+ val initializeWithWrongParamType = InitializeJavaBean(
+ Literal.fromObject(new TestBean),
+ Map("setX" -> Literal("1")))
+ intercept[Exception] {
+ evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq()))
+ }.getMessage.contains(
+ """A method named "setX" is not declared in any enclosing class """ +
+ "nor any supertype")
+ }
+
+ test("InitializeJavaBean doesn't call setters if input in null") {
+ val initializeBean = InitializeJavaBean(
+ Literal.fromObject(new TestBean),
+ Map("setNonPrimitive" -> Literal(null)))
+ evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq()))
+ evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq()))
+
+ val initializeBean2 = InitializeJavaBean(
+ Literal.fromObject(new TestBean),
+ Map("setNonPrimitive" -> Literal("string")))
+ evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq()))
+ evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq()))
+ }
+
+ test("SPARK-23585: UnwrapOption should support interpreted execution") {
+ val cls = classOf[Option[Int]]
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+ val unwrapObject = UnwrapOption(IntegerType, inputObject)
+ Seq((Some(1), 1), (None, null), (null, null)).foreach { case (input, expected) =>
+ checkEvaluation(unwrapObject, expected, InternalRow.fromSeq(Seq(input)))
+ }
+ }
+
+ test("SPARK-23586: WrapOption should support interpreted execution") {
+ val cls = ObjectType(classOf[java.lang.Integer])
+ val inputObject = BoundReference(0, cls, nullable = true)
+ val wrapObject = WrapOption(inputObject, cls)
+ Seq((1, Some(1)), (null, None)).foreach { case (input, expected) =>
+ checkEvaluation(wrapObject, expected, InternalRow.fromSeq(Seq(input)))
+ }
+ }
+
+ test("SPARK-23590: CreateExternalRow should support interpreted execution") {
+ val schema = new StructType().add("a", IntegerType).add("b", StringType)
+ val createExternalRow = CreateExternalRow(Seq(Literal(1), Literal("x")), schema)
+ checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq()))
+ }
+
+ // Checks evaluation of object-related expressions by scala values instead of catalyst values.
+ private def checkObjectExprEvaluation(
+ expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
+ val serializer = new JavaSerializer(new SparkConf()).newInstance
+ val resolver = ResolveTimeZone(new SQLConf)
+ val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
+ checkEvaluationWithoutCodegen(expr, expected, inputRow)
+ checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow)
+ if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
+ checkEvaluationWithUnsafeProjection(
+ expr,
+ expected,
+ inputRow)
+ }
+ checkEvaluationWithOptimization(expr, expected, inputRow)
+ }
+
+ test("SPARK-23594 GetExternalRowField should support interpreted execution") {
+ val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
+ val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0")
+ Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) =>
+ checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input)))
+ }
+
+ // If the input row or a field is null, a runtime exception will be thrown
+ checkExceptionInExpression[RuntimeException](
+ getRowField,
+ InternalRow.fromSeq(Seq(null)),
+ "The input external row cannot be null.")
+ checkExceptionInExpression[RuntimeException](
+ getRowField,
+ InternalRow.fromSeq(Seq(Row(null))),
+ "The 0th field 'c0' of input row cannot be null.")
+ }
+
+ test("SPARK-23591: EncodeUsingSerializer should support interpreted execution") {
+ val cls = ObjectType(classOf[java.lang.Integer])
+ val inputObject = BoundReference(0, cls, nullable = true)
+ val conf = new SparkConf()
+ Seq(true, false).foreach { useKryo =>
+ val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf)
+ val expected = serializer.newInstance().serialize(new Integer(1)).array()
+ val encodeUsingSerializer = EncodeUsingSerializer(inputObject, useKryo)
+ checkEvaluation(encodeUsingSerializer, expected, InternalRow.fromSeq(Seq(1)))
+ checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null)))
+ }
+ }
+
+ test("SPARK-23587: MapObjects should support interpreted execution") {
+ def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = {
+ val function = (lambda: Expression) => Add(lambda, Literal(1))
+ val elementType = IntegerType
+ val expected = Seq(2, 3, 4)
+
+ val inputObject = BoundReference(0, inputType, nullable = true)
+ val optClass = Option(collectionCls)
+ val mapObj = MapObjects(function, inputObject, elementType, true, optClass)
+ val row = InternalRow.fromSeq(Seq(collection))
+ val result = mapObj.eval(row)
+
+ collectionCls match {
+ case null =>
+ assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
+ case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
+ assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected)
+ case s if classOf[Seq[_]].isAssignableFrom(s) =>
+ assert(result.asInstanceOf[Seq[_]].toSeq == expected)
+ case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
+ assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
+ }
+ }
+
+ val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
+ classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
+ classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
+ classOf[java.util.Stack[Int]], null)
+
+ val list = new java.util.ArrayList[Int]()
+ list.add(1)
+ list.add(2)
+ list.add(3)
+ val arrayData = new GenericArrayData(Array(1, 2, 3))
+ val vector = new java.util.Vector[Int]()
+ vector.add(1)
+ vector.add(2)
+ vector.add(3)
+ val stack = new java.util.Stack[Int]()
+ stack.add(1)
+ stack.add(2)
+ stack.add(3)
+
+ Seq(
+ (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
+ (Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
+ (Seq(1, 2, 3), ObjectType(classOf[Object])),
+ (Array(1, 2, 3), ObjectType(classOf[Object])),
+ (list, ObjectType(classOf[java.util.List[Int]])),
+ (vector, ObjectType(classOf[java.util.Vector[Int]])),
+ (stack, ObjectType(classOf[java.util.Stack[Int]])),
+ (arrayData, ArrayType(IntegerType))
+ ).foreach { case (collection, inputType) =>
+ customCollectionClasses.foreach(testMapObjects(collection, _, inputType))
+
+ // Unsupported custom collection class
+ val errMsg = intercept[RuntimeException] {
+ testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType)
+ }.getMessage()
+ assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " +
+ "as resulting collection."))
+ }
+ }
+
+ test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") {
+ val cls = classOf[java.lang.Integer]
+ val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true)
+ val conf = new SparkConf()
+ Seq(true, false).foreach { useKryo =>
+ val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf)
+ val input = serializer.newInstance().serialize(new Integer(1)).array()
+ val decodeUsingSerializer = DecodeUsingSerializer(inputObject, ClassTag(cls), useKryo)
+ checkEvaluation(decodeUsingSerializer, new Integer(1), InternalRow.fromSeq(Seq(input)))
+ checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null)))
+ }
+ }
+
+ test("SPARK-23584 NewInstance should support interpreted execution") {
+ // Normal case test
+ val newInst1 = NewInstance(
+ cls = classOf[GenericArrayData],
+ arguments = Literal.fromObject(List(1, 2, 3)) :: Nil,
+ propagateNull = false,
+ dataType = ArrayType(IntegerType),
+ outerPointer = None)
+ checkObjectExprEvaluation(newInst1, new GenericArrayData(List(1, 2, 3)))
+
+ // Inner class case test
+ val outerObj = new Outer()
+ val newInst2 = NewInstance(
+ cls = classOf[outerObj.Inner],
+ arguments = Literal(1) :: Nil,
+ propagateNull = false,
+ dataType = ObjectType(classOf[outerObj.Inner]),
+ outerPointer = Some(() => outerObj))
+ checkObjectExprEvaluation(newInst2, new outerObj.Inner(1))
+ }
+
+ test("LambdaVariable should support interpreted execution") {
+ def genSchema(dt: DataType): Seq[StructType] = {
+ Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil),
+ StructType(StructField("col_1", dt, nullable = true) :: Nil))
+ }
+
+ val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
+ DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType,
+ CalendarIntervalType, new ExamplePointUDT())
+ val arrayTypes = elementTypes.flatMap { elementType =>
+ Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true))
+ }
+ val mapTypes = elementTypes.flatMap { elementType =>
+ Seq(MapType(elementType, elementType, false), MapType(elementType, elementType, true))
+ }
+ val structTypes = elementTypes.flatMap { elementType =>
+ Seq(StructType(StructField("col1", elementType, false) :: Nil),
+ StructType(StructField("col1", elementType, true) :: Nil))
+ }
+
+ val testTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes
+ val random = new Random(100)
+ testTypes.foreach { dt =>
+ genSchema(dt).map { schema =>
+ val row = RandomDataGenerator.randomRow(random, schema)
+ val rowConverter = RowEncoder(schema)
+ val internalRow = rowConverter.toRow(row)
+ val lambda = LambdaVariable("dummy", "dummuIsNull", schema(0).dataType, schema(0).nullable)
+ checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow)
+ }
+ }
+ }
+
+ implicit private def mapIntStrEncoder = ExpressionEncoder[Map[Int, String]]()
+
+ test("SPARK-23588 CatalystToExternalMap should support interpreted execution") {
+ // To get a resolved `CatalystToExternalMap` expression, we build a deserializer plan
+ // with dummy input, resolve the plan by the analyzer, and replace the dummy input
+ // with a literal for tests.
+ val unresolvedDeser = UnresolvedDeserializer(encoderFor[Map[Int, String]].deserializer)
+ val dummyInputPlan = LocalRelation('value.map(MapType(IntegerType, StringType)))
+ val plan = Project(Alias(unresolvedDeser, "none")() :: Nil, dummyInputPlan)
+
+ val analyzedPlan = SimpleAnalyzer.execute(plan)
+ val Alias(toMapExpr: CatalystToExternalMap, _) = analyzedPlan.expressions.head
+
+ // Replace the dummy input with a literal for this test
+ val data = Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3")
+ val deserializer = toMapExpr.copy(inputData = Literal.create(data))
+ checkObjectExprEvaluation(deserializer, expected = data)
+ }
+
+ test("SPARK-23595 ValidateExternalType should support interpreted execution") {
+ val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
+ Seq(
+ (true, BooleanType),
+ (2.toByte, ByteType),
+ (5.toShort, ShortType),
+ (23, IntegerType),
+ (61L, LongType),
+ (1.0f, FloatType),
+ (10.0, DoubleType),
+ ("abcd".getBytes, BinaryType),
+ ("abcd", StringType),
+ (BigDecimal.valueOf(10), DecimalType.IntDecimal),
+ (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType),
+ (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
+ (Array(3, 2, 1), ArrayType(IntegerType))
+ ).foreach { case (input, dt) =>
+ val validateType = ValidateExternalType(
+ GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt)
+ checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input))))
+ }
+
+ checkExceptionInExpression[RuntimeException](
+ ValidateExternalType(
+ GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType),
+ InternalRow.fromSeq(Seq(Row(1))),
+ "java.lang.Integer is not a valid external type for schema of double")
+ }
+
+ private def javaMapSerializerFor(
+ keyClazz: Class[_],
+ valueClazz: Class[_])(inputObject: Expression): Expression = {
+
+ def kvSerializerFor(inputObject: Expression, clazz: Class[_]): Expression = clazz match {
+ case c if c == classOf[java.lang.Integer] =>
+ Invoke(inputObject, "intValue", IntegerType)
+ case c if c == classOf[java.lang.String] =>
+ StaticInvoke(
+ classOf[UTF8String],
+ StringType,
+ "fromString",
+ inputObject :: Nil,
+ returnNullable = false)
+ }
+
+ ExternalMapToCatalyst(
+ inputObject,
+ ObjectType(keyClazz),
+ kvSerializerFor(_, keyClazz),
+ keyNullable = true,
+ ObjectType(valueClazz),
+ kvSerializerFor(_, valueClazz),
+ valueNullable = true
+ )
+ }
+
+ private def scalaMapSerializerFor[T: TypeTag, U: TypeTag](inputObject: Expression): Expression = {
+ import org.apache.spark.sql.catalyst.ScalaReflection._
+
+ val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+ def kvSerializerFor[V: TypeTag](inputObject: Expression): Expression =
+ localTypeOf[V].dealias match {
+ case t if t <:< localTypeOf[java.lang.Integer] =>
+ Invoke(inputObject, "intValue", IntegerType)
+ case t if t <:< localTypeOf[String] =>
+ StaticInvoke(
+ classOf[UTF8String],
+ StringType,
+ "fromString",
+ inputObject :: Nil,
+ returnNullable = false)
+ case _ =>
+ inputObject
+ }
+
+ ExternalMapToCatalyst(
+ inputObject,
+ dataTypeFor[T],
+ kvSerializerFor[T],
+ keyNullable = !localTypeOf[T].typeSymbol.asClass.isPrimitive,
+ dataTypeFor[U],
+ kvSerializerFor[U],
+ valueNullable = !localTypeOf[U].typeSymbol.asClass.isPrimitive
+ )
+ }
+
+ test("SPARK-23589 ExternalMapToCatalyst should support interpreted execution") {
+ // Simple test
+ val scalaMap = scala.collection.Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3")
+ val javaMap = new java.util.HashMap[java.lang.Integer, java.lang.String]() {
+ {
+ put(0, "v0")
+ put(1, "v1")
+ put(2, null)
+ put(3, "v3")
+ }
+ }
+ val expected = CatalystTypeConverters.convertToCatalyst(scalaMap)
+
+ // Java Map
+ val serializer1 = javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])(
+ Literal.fromObject(javaMap))
+ checkEvaluation(serializer1, expected)
+
+ // Scala Map
+ val serializer2 = scalaMapSerializerFor[Int, String](Literal.fromObject(scalaMap))
+ checkEvaluation(serializer2, expected)
+
+ // NULL key test
+ val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String](
+ null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1")
+ val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() {
+ {
+ put(null, "v0")
+ put(1, "v1")
+ }
+ }
+
+ // Java Map
+ val serializer3 =
+ javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])(
+ Literal.fromObject(javaMapHasNullKey))
+ checkExceptionInExpression[RuntimeException](
+ serializer3, EmptyRow, "Cannot use null as map key!")
+
+ // Scala Map
+ val serializer4 = scalaMapSerializerFor[java.lang.Integer, String](
+ Literal.fromObject(scalaMapHasNullKey))
+
+ checkExceptionInExpression[RuntimeException](
+ serializer4, EmptyRow, "Cannot use null as map key!")
+ }
+}
+
+class TestBean extends Serializable {
+ private var x: Int = 0
+
+ def setX(i: Int): Unit = x = i
+
+ def setNonPrimitive(i: AnyRef): Unit =
+ assert(i != null, "this setter should not be called with null.")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 8a8f8e10225fa..ac76b17ef4761 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -442,4 +442,17 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}
+
+ test("SPARK-24007: EqualNullSafe for FloatType and DoubleType might generate a wrong result") {
+ checkEvaluation(EqualNullSafe(Literal(null, FloatType), Literal(-1.0f)), false)
+ checkEvaluation(EqualNullSafe(Literal(-1.0f), Literal(null, FloatType)), false)
+ checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false)
+ checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false)
+ }
+
+ test("Interpreted Predicate should initialize nondeterministic expressions") {
+ val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0)))
+ interpreted.initialize(0)
+ assert(interpreted.eval(new UnsafeRow()))
+ }
}
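
The last test above guards against evaluating a nondeterministic child before it is initialized; a minimal sketch (spark-catalyst classpath assumed, expression chosen for illustration) of the initialize-before-eval contract that InterpretedPredicate now honors:

import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}

object NondeterministicInitSketch {
  def main(args: Array[String]): Unit = {
    val rand = Rand(Literal(7L))
    // Nondeterministic expressions need a partition index before eval();
    // InterpretedPredicate.create(...).initialize(0) does this for every
    // nondeterministic expression nested in the predicate.
    rand.initialize(0)
    val value = rand.eval().asInstanceOf[Double]
    assert(value >= 0.0 && value < 1.0)
  }
}
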
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index 2a0a42c65b086..d532dc4f77198 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -100,12 +100,12 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// invalid escaping
val invalidEscape = intercept[AnalysisException] {
- evaluate("""a""" like """\a""")
+ evaluateWithoutCodegen("""a""" like """\a""")
}
assert(invalidEscape.getMessage.contains("pattern"))
val endEscape = intercept[AnalysisException] {
- evaluate("""a""" like """a\""")
+ evaluateWithoutCodegen("""a""" like """a\""")
}
assert(endEscape.getMessage.contains("pattern"))
@@ -147,11 +147,11 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkLiteralRow("abc" rlike _, "^bc", false)
intercept[java.util.regex.PatternSyntaxException] {
- evaluate("abbbbc" rlike "**")
+ evaluateWithoutCodegen("abbbbc" rlike "**")
}
intercept[java.util.regex.PatternSyntaxException] {
val regex = 'a.string.at(0)
- evaluate("abbbbc" rlike regex, create_row("**"))
+ evaluateWithoutCodegen("abbbbc" rlike regex, create_row("**"))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index 10e3ffd0dff97..e083ae0089244 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -43,7 +43,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(e1.getMessage.contains("Failed to execute user defined function"))
val e2 = intercept[SparkException] {
- checkEvalutionWithUnsafeProjection(udf, null)
+ checkEvaluationWithUnsafeProjection(udf, null)
}
assert(e2.getMessage.contains("Failed to execute user defined function"))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala
new file mode 100644
index 0000000000000..cc2e2a993d629
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.sql.{Date, Timestamp}
+import java.util.TimeZone
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
+
+class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ test("SortPrefix") {
+ val b1 = Literal.create(false, BooleanType)
+ val b2 = Literal.create(true, BooleanType)
+ val i1 = Literal.create(20132983, IntegerType)
+ val i2 = Literal.create(-20132983, IntegerType)
+ val l1 = Literal.create(20132983, LongType)
+ val l2 = Literal.create(-20132983, LongType)
+ val millis = 1524954911000L
+ // Explicitly choose a time zone, since Date objects can create different values depending on
+ // local time zone of the machine on which the test is running
+ val oldDefaultTZ = TimeZone.getDefault
+ val d1 = try {
+ TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
+ Literal.create(new java.sql.Date(millis), DateType)
+ } finally {
+ TimeZone.setDefault(oldDefaultTZ)
+ }
+ val t1 = Literal.create(new Timestamp(millis), TimestampType)
+ val f1 = Literal.create(0.7788229f, FloatType)
+ val f2 = Literal.create(-0.7788229f, FloatType)
+ val db1 = Literal.create(0.7788229d, DoubleType)
+ val db2 = Literal.create(-0.7788229d, DoubleType)
+ val s1 = Literal.create("T", StringType)
+ val s2 = Literal.create("This is longer than 8 characters", StringType)
+ val bin1 = Literal.create(Array[Byte](12), BinaryType)
+ val bin2 = Literal.create(Array[Byte](12, 17, 99, 0, 0, 0, 2, 3, 0xf4.asInstanceOf[Byte]),
+ BinaryType)
+ val dec1 = Literal(Decimal(20132983L, 10, 2))
+ val dec2 = Literal(Decimal(20132983L, 19, 2))
+ val dec3 = Literal(Decimal(20132983L, 21, 2))
+ val list1 = Literal(List(1, 2), ArrayType(IntegerType))
+ val nullVal = Literal.create(null, IntegerType)
+
+ checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L)
+ checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L)
+ checkEvaluation(SortPrefix(SortOrder(i1, Ascending)), 20132983L)
+ checkEvaluation(SortPrefix(SortOrder(i2, Ascending)), -20132983L)
+ checkEvaluation(SortPrefix(SortOrder(l1, Ascending)), 20132983L)
+ checkEvaluation(SortPrefix(SortOrder(l2, Ascending)), -20132983L)
+ // DateType literals are stored internally as the number of days since the Unix epoch
+ checkEvaluation(SortPrefix(SortOrder(d1, Ascending)), 17649L)
+ checkEvaluation(SortPrefix(SortOrder(t1, Ascending)), millis * 1000)
+ checkEvaluation(SortPrefix(SortOrder(f1, Ascending)),
+ DoublePrefixComparator.computePrefix(f1.value.asInstanceOf[Float].toDouble))
+ checkEvaluation(SortPrefix(SortOrder(f2, Ascending)),
+ DoublePrefixComparator.computePrefix(f2.value.asInstanceOf[Float].toDouble))
+ checkEvaluation(SortPrefix(SortOrder(db1, Ascending)),
+ DoublePrefixComparator.computePrefix(db1.value.asInstanceOf[Double]))
+ checkEvaluation(SortPrefix(SortOrder(db2, Ascending)),
+ DoublePrefixComparator.computePrefix(db2.value.asInstanceOf[Double]))
+ checkEvaluation(SortPrefix(SortOrder(s1, Ascending)),
+ StringPrefixComparator.computePrefix(s1.value.asInstanceOf[UTF8String]))
+ checkEvaluation(SortPrefix(SortOrder(s2, Ascending)),
+ StringPrefixComparator.computePrefix(s2.value.asInstanceOf[UTF8String]))
+ checkEvaluation(SortPrefix(SortOrder(bin1, Ascending)),
+ BinaryPrefixComparator.computePrefix(bin1.value.asInstanceOf[Array[Byte]]))
+ checkEvaluation(SortPrefix(SortOrder(bin2, Ascending)),
+ BinaryPrefixComparator.computePrefix(bin2.value.asInstanceOf[Array[Byte]]))
+ checkEvaluation(SortPrefix(SortOrder(dec1, Ascending)), 20132983L)
+ checkEvaluation(SortPrefix(SortOrder(dec2, Ascending)), 2013298L)
+ checkEvaluation(SortPrefix(SortOrder(dec3, Ascending)),
+ DoublePrefixComparator.computePrefix(201329.83d))
+ checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L)
+ checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null)
+ }
+}
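As an illustrative aside (not part of the patch), every expectation in the new suite reduces a sort key to one 64-bit prefix; a minimal sketch, reusing only the PrefixComparators helpers the suite already imports (the object name is made up for illustration):

import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._

object SortPrefixSketch extends App {
  // Integer and long keys use the (sign-extended) value itself as the prefix, e.g. 20132983L above.
  val longPrefix = 20132983L
  // String keys derive the prefix from at most their first 8 bytes.
  val stringPrefix = StringPrefixComparator.computePrefix(
    UTF8String.fromString("This is longer than 8 characters"))
  // Float keys are widened to double and normalized through the double prefix helper.
  val floatPrefix = DoublePrefixComparator.computePrefix(0.7788229f.toDouble)
  println(s"long=$longPrefix string=$stringPrefix float=$floatPrefix")
}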
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 97ddbeba2c5ca..aa334e040d5fc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("REVERSE") {
val s = 'a.string.at(0)
val row1 = create_row("abccc")
- checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
- checkEvaluation(StringReverse(s), "cccba", row1)
- checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1)
+ checkEvaluation(Reverse(Literal("abccc")), "cccba", row1)
+ checkEvaluation(Reverse(s), "cccba", row1)
+ checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1)
}
test("SPACE") {
@@ -706,6 +706,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"15,159,339,180,002,773.2778")
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false)
+
+ checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123")
+ checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123")
+ checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4")
+ checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4")
+ checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4")
+ checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4")
+ checkEvaluation(FormatNumber(Literal(12831273.23481d),
+ Literal("###,###,###,###,###.###")), "12,831,273.235")
+ checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274")
+ checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")),
+ "123,123,324,123")
+ checkEvaluation(
+ FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)),
+ Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778")
+ checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null)
+ assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false)
+
+ checkEvaluation(FormatNumber(Literal(12332.123456), Literal("#,###,###,###,###,###,##0")),
+ "12,332")
+ checkEvaluation(FormatNumber(
+ Literal.create(null, IntegerType), Literal.create(null, StringType)), null)
+ checkEvaluation(FormatNumber(
+ Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null)
}
test("find in set") {
@@ -756,7 +780,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// exceptional cases
intercept[java.util.regex.PatternSyntaxException] {
- evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"),
+ evaluateWithoutCodegen(ParseUrl(Seq(Literal("http://spark.apache.org/path?"),
Literal("QUERY"), Literal("???"))))
}
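The FormatNumber expectations added above track plain java.text.DecimalFormat behaviour: '#' placeholders never drop integer digits, they only cap fraction digits and suppress trailing zeros. A standalone sketch under that assumption (helper and object names are illustrative):

import java.text.{DecimalFormat, DecimalFormatSymbols}
import java.util.Locale

object FormatNumberSketch extends App {
  val symbols = DecimalFormatSymbols.getInstance(Locale.US)
  def fmt(pattern: String, value: Double): String = new DecimalFormat(pattern, symbols).format(value)
  println(fmt("##.###", 12332.123456))                     // 12332.123 -- integer digits are kept
  println(fmt("##.####", 4.0))                             // 4 -- trailing fraction zeros suppressed
  println(fmt("###,###,###,###,###.###", 12831273.23481))  // 12,831,273.235 -- grouping from the pattern
}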
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
index d6c8fcf291842..351d4d0c2eac9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
@@ -27,7 +27,7 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
test("time window is unevaluable") {
intercept[UnsupportedOperationException] {
- evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second"))
+ evaluateWithoutCodegen(TimeWindow(Literal(10L), "1 second", "1 second", "0 second"))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index cf3cbe270753e..5a646d9a850ac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -24,19 +24,32 @@ import org.scalatest.Matchers
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, LongType, _}
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.UTF8String
-class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
+class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase {
private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size)
- test("basic conversion with only primitive types") {
- val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
- val converter = UnsafeProjection.create(fieldTypes)
+ private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = {
+ val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN)
+ for (fallbackMode <- modes) {
+ test(s"$name with $fallbackMode") {
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) {
+ f
+ }
+ }
+ }
+ }
+ testBothCodegenAndInterpreted("basic conversion with only primitive types") {
+ val factory = UnsafeProjection
+ val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
+ val converter = factory.create(fieldTypes)
val row = new SpecificInternalRow(fieldTypes)
row.setLong(0, 0)
row.setLong(1, 1)
@@ -71,9 +84,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow2.getInt(2) === 2)
}
- test("basic conversion with primitive, string and binary types") {
+ testBothCodegenAndInterpreted("basic conversion with primitive, string and binary types") {
+ val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new SpecificInternalRow(fieldTypes)
row.setLong(0, 0)
@@ -90,9 +104,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8))
}
- test("basic conversion with primitive, string, date and timestamp types") {
+ testBothCodegenAndInterpreted(
+ "basic conversion with primitive, string, date and timestamp types") {
+ val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new SpecificInternalRow(fieldTypes)
row.setLong(0, 0)
@@ -119,7 +135,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
(Timestamp.valueOf("2015-06-22 08:10:25"))
}
- test("null handling") {
+ testBothCodegenAndInterpreted("null handling") {
+ val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(
NullType,
BooleanType,
@@ -135,7 +152,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
DecimalType.SYSTEM_DEFAULT
// ArrayType(IntegerType)
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val rowWithAllNullColumns: InternalRow = {
val r = new SpecificInternalRow(fieldTypes)
@@ -240,7 +257,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}
- test("NaN canonicalization") {
+ testBothCodegenAndInterpreted("NaN canonicalization") {
+ val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)
val row1 = new SpecificInternalRow(fieldTypes)
@@ -251,17 +269,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
}
- test("basic conversion with struct type") {
+ testBothCodegenAndInterpreted("basic conversion with struct type") {
+ val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(
new StructType().add("i", IntegerType),
new StructType().add("nest", new StructType().add("l", LongType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new GenericInternalRow(fieldTypes.length)
row.update(0, InternalRow(1))
@@ -317,12 +336,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
}
- test("basic conversion with array type") {
+ testBothCodegenAndInterpreted("basic conversion with array type") {
+ val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(
ArrayType(IntegerType),
ArrayType(ArrayType(IntegerType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new GenericInternalRow(fieldTypes.length)
row.update(0, createArray(1, 2))
@@ -347,12 +367,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
}
- test("basic conversion with map type") {
+ testBothCodegenAndInterpreted("basic conversion with map type") {
+ val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(
MapType(IntegerType, IntegerType),
MapType(IntegerType, MapType(IntegerType, IntegerType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val map1 = createMap(1, 2)(3, 4)
@@ -393,12 +414,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
}
- test("basic conversion with struct and array") {
+ testBothCodegenAndInterpreted("basic conversion with struct and array") {
+ val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(
new StructType().add("arr", ArrayType(IntegerType)),
ArrayType(new StructType().add("l", LongType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new GenericInternalRow(fieldTypes.length)
row.update(0, InternalRow(createArray(1)))
@@ -432,12 +454,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
}
- test("basic conversion with struct and map") {
+ testBothCodegenAndInterpreted("basic conversion with struct and map") {
+ val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(
new StructType().add("map", MapType(IntegerType, IntegerType)),
MapType(IntegerType, new StructType().add("l", LongType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new GenericInternalRow(fieldTypes.length)
row.update(0, InternalRow(createMap(1)(2)))
@@ -478,12 +501,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
}
- test("basic conversion with array and map") {
+ testBothCodegenAndInterpreted("basic conversion with array and map") {
+ val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(
ArrayType(MapType(IntegerType, IntegerType)),
MapType(IntegerType, ArrayType(IntegerType))
)
- val converter = UnsafeProjection.create(fieldTypes)
+ val converter = factory.create(fieldTypes)
val row = new GenericInternalRow(fieldTypes.length)
row.update(0, createArray(createMap(1)(2)))
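The getSizeInBytes assertions in this suite all follow the same arithmetic: a word-aligned null bitset, one 8-byte slot per field, then the word-rounded variable-length region those slots point into. A small sketch of that arithmetic (an assumption about the layout, with an invented 21-byte payload for illustration):

object UnsafeRowSizeSketch extends App {
  def roundToWord(n: Int): Int = ((n + 7) / 8) * 8   // same rounding as roundedSize above
  val numFields = 2
  val bitsetBytes = ((numFields + 63) / 64) * 8      // 8 bytes covers up to 64 fields
  val fixedBytes = 8 * numFields                     // one 8-byte slot per field
  val variableBytes = roundToWord(21)                // e.g. a 21-byte string or array payload
  println(s"expected getSizeInBytes = ${bitsetBytes + fixedBytes + variableBytes}")
}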
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
new file mode 100644
index 0000000000000..d2c6420eadb20
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
+class CodeBlockSuite extends SparkFunSuite {
+
+ test("Block interpolates string and ExprValue inputs") {
+ val isNull = JavaCode.isNullVariable("expr1_isNull")
+ val stringLiteral = "false"
+ val code = code"boolean $isNull = $stringLiteral;"
+ assert(code.toString == "boolean expr1_isNull = false;")
+ }
+
+ test("Literals are folded into string code parts instead of block inputs") {
+ val value = JavaCode.variable("expr1", IntegerType)
+ val intLiteral = 1
+ val code = code"int $value = $intLiteral;"
+ assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value))
+ }
+
+ test("Block.stripMargin") {
+ val isNull = JavaCode.isNullVariable("expr1_isNull")
+ val value = JavaCode.variable("expr1", IntegerType)
+ val code1 =
+ code"""
+ |boolean $isNull = false;
+ |int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin
+ val expected =
+ s"""
+ |boolean expr1_isNull = false;
+ |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim
+ assert(code1.toString == expected)
+
+ val code2 =
+ code"""
+ >boolean $isNull = false;
+ >int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>')
+ assert(code2.toString == expected)
+ }
+
+ test("Block can capture input expr values") {
+ val isNull = JavaCode.isNullVariable("expr1_isNull")
+ val value = JavaCode.variable("expr1", IntegerType)
+ val code =
+ code"""
+ |boolean $isNull = false;
+ |int $value = -1;
+ """.stripMargin
+ val exprValues = code.exprValues
+ assert(exprValues.size == 2)
+ assert(exprValues === Set(value, isNull))
+ }
+
+ test("concatenate blocks") {
+ val isNull1 = JavaCode.isNullVariable("expr1_isNull")
+ val value1 = JavaCode.variable("expr1", IntegerType)
+ val isNull2 = JavaCode.isNullVariable("expr2_isNull")
+ val value2 = JavaCode.variable("expr2", IntegerType)
+ val literal = JavaCode.literal("100", IntegerType)
+
+ val code =
+ code"""
+ |boolean $isNull1 = false;
+ |int $value1 = -1;""".stripMargin +
+ code"""
+ |boolean $isNull2 = true;
+ |int $value2 = $literal;""".stripMargin
+
+ val expected =
+ """
+ |boolean expr1_isNull = false;
+ |int expr1 = -1;
+ |boolean expr2_isNull = true;
+ |int expr2 = 100;""".stripMargin.trim
+
+ assert(code.toString == expected)
+
+ val exprValues = code.exprValues
+ assert(exprValues.size == 5)
+ assert(exprValues === Set(isNull1, value1, isNull2, value2, literal))
+ }
+
+ test("Throws exception when interpolating unexcepted object in code block") {
+ val obj = Tuple2(1, 1)
+ val e = intercept[IllegalArgumentException] {
+ code"$obj"
+ }
+ assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}"))
+ }
+
+ test("replace expr values in code block") {
+ val expr = JavaCode.expression("1 + 1", IntegerType)
+ val isNull = JavaCode.isNullVariable("expr1_isNull")
+ val exprInFunc = JavaCode.variable("expr1", IntegerType)
+
+ val code =
+ code"""
+ |callFunc(int $expr) {
+ | boolean $isNull = false;
+ | int $exprInFunc = $expr + 1;
+ |}""".stripMargin
+
+ val aliasedParam = JavaCode.variable("aliased", expr.javaType)
+ val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map {
+ case _: SimpleExprValue => aliasedParam
+ case other => other
+ }
+ val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin
+ val expected =
+ code"""
+ |callFunc(int $aliasedParam) {
+ | boolean $isNull = false;
+ | int $exprInFunc = $aliasedParam + 1;
+ |}""".stripMargin
+ assert(aliasedCode.toString == expected.toString)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala
new file mode 100644
index 0000000000000..378b8bc055e34
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.BooleanType
+
+class ExprValueSuite extends SparkFunSuite {
+
+ test("TrueLiteral and FalseLiteral should be LiteralValue") {
+ val trueLit = TrueLiteral
+ val falseLit = FalseLiteral
+
+ assert(trueLit.value == "true")
+ assert(falseLit.value == "false")
+
+ assert(trueLit.isPrimitive)
+ assert(falseLit.isPrimitive)
+
+ assert(trueLit === JavaCode.literal("true", BooleanType))
+ assert(falseLit === JavaCode.literal("false", BooleanType))
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala
similarity index 51%
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala
index 91e9a9f211335..1b25a4b191f86 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala
@@ -1,38 +1,42 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit.steps
-
-import org.apache.spark.deploy.k8s.MountSecretsBootstrap
-import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec
-
-/**
- * A driver configuration step for mounting user-specified secrets onto user-specified paths.
- *
- * @param bootstrap a utility actually handling mounting of the secrets.
- */
-private[spark] class DriverMountSecretsStep(
- bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep {
-
- override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = {
- val pod = bootstrap.addSecretVolumes(driverSpec.driverPod)
- val container = bootstrap.mountSecrets(driverSpec.driverContainer)
- driverSpec.copy(
- driverPod = pod,
- driverContainer = container
- )
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.unsafe.types.UTF8String
+
+class UTF8StringBuilderSuite extends SparkFunSuite {
+
+ test("basic test") {
+ val sb = new UTF8StringBuilder()
+ assert(sb.build() === UTF8String.EMPTY_UTF8)
+
+ sb.append("")
+ assert(sb.build() === UTF8String.EMPTY_UTF8)
+
+ sb.append("abcd")
+ assert(sb.build() === UTF8String.fromString("abcd"))
+
+ sb.append(UTF8String.fromString("1234"))
+ assert(sb.build() === UTF8String.fromString("abcd1234"))
+
+ // expect to grow an internal buffer
+ sb.append(UTF8String.fromString("efgijk567890"))
+ assert(sb.build() === UTF8String.fromString("abcd1234efgijk567890"))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala
index c4cde7091154b..0fec15bc42c17 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala
@@ -77,6 +77,27 @@ class UDFXPathUtilSuite extends SparkFunSuite {
assert(ret == "foo")
}
+ test("embedFailure") {
+ import org.apache.commons.io.FileUtils
+ import java.io.File
+ val secretValue = String.valueOf(Math.random)
+ val tempFile = File.createTempFile("verifyembed", ".tmp")
+ tempFile.deleteOnExit()
+ val fname = tempFile.getAbsolutePath
+
+ FileUtils.writeStringToFile(tempFile, secretValue)
+
+ val xml =
+ s"""
+ |
+ |]>
+ |&embed;
+ """.stripMargin
+ val evaled = new UDFXPathUtil().evalString(xml, "/foo")
+ assert(evaled.isEmpty)
+ }
+
test("number eval") {
var ret =
util.evalNumber("<a><b>true</b><b>false</b><b>b3</b><c>c1</c><c>-77</c></a>", "a/c[2]")
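The new embedFailure test above is an XML external entity (XXE) check: if &embed; were resolved, the secret written to the temp file would leak into the XPath result, so the evaluated /foo text must come back empty. For comparison only, a generic JAXP hardening sketch (standard DocumentBuilderFactory calls; not a claim about how UDFXPathUtil is implemented):

import javax.xml.parsers.DocumentBuilderFactory

object SecureXmlParsingSketch extends App {
  val dbf = DocumentBuilderFactory.newInstance()
  // Rejecting DOCTYPE declarations blocks entity-based file disclosure outright.
  dbf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true)
  dbf.setXIncludeAware(false)
  dbf.setExpandEntityReferences(false)
  println(s"hardened parser ready: ${dbf.newDocumentBuilder() != null}")
}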
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
index bfa18a0919e45..c6f6d3abb860c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
@@ -40,8 +40,9 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
// Test error message for invalid XML document
val e1 = intercept[RuntimeException] { testExpr("<a>/a>", "a", null.asInstanceOf[T]) }
- assert(e1.getCause.getMessage.contains("Invalid XML document") &&
- e1.getCause.getMessage.contains("<a>/a>"))
+ assert(e1.getCause.getCause.getMessage.contains(
+ "XML document structures must start and end within the same entity."))
+ assert(e1.getMessage.contains("<a>/a>"))
// Test error message for invalid xpath
val e2 = intercept[RuntimeException] { testExpr("<a></a>", "!#$", null.asInstanceOf[T]) }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala
index 21220b38968e8..788fedb3c8e8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala
@@ -56,7 +56,7 @@ class CheckCartesianProductsSuite extends PlanTest {
val thrownException = the [AnalysisException] thrownBy {
performCartesianProductCheck(joinType)
}
- assert(thrownException.message.contains("Detected cartesian product"))
+ assert(thrownException.message.contains("Detected implicit cartesian product"))
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 178c4b8c270a0..e4671f0d1cce6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -35,11 +35,25 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
InferFiltersFromConstraints,
CombineFilters,
SimplifyBinaryComparison,
- BooleanSimplification) :: Nil
+ BooleanSimplification,
+ PruneFilters) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ private def testConstraintsAfterJoin(
+ x: LogicalPlan,
+ y: LogicalPlan,
+ expectedLeft: LogicalPlan,
+ expectedRight: LogicalPlan,
+ joinType: JoinType) = {
+ val condition = Some("x.a".attr === "y.a".attr)
+ val originalQuery = x.join(y, joinType, condition).analyze
+ val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
test("filter: filter out constraints in condition") {
val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
val correctAnswer = testRelation
@@ -192,4 +206,61 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
comparePlans(Optimize.execute(original.analyze), correct.analyze)
}
+
+ test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi)
+ }
+
+ test("SPARK-21479: Outer join after-join filters push down to null-supplying side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val condition = Some("x.a".attr === "y.a".attr)
+ val originalQuery = x.join(y, LeftOuter, condition).where("x.a".attr === 2).analyze
+ val left = x.where(IsNotNull('a) && 'a === 2)
+ val right = y.where(IsNotNull('a) && 'a === 2)
+ val correctAnswer = left.join(right, LeftOuter, condition).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ val condition = Some("x.a".attr === "y.a".attr)
+ val originalQuery = x.join(y.where("y.a".attr > 5), RightOuter, condition).analyze
+ val left = x.where(IsNotNull('a) && 'a > 5)
+ val right = y.where(IsNotNull('a) && 'a > 5)
+ val correctAnswer = left.join(right, RightOuter, condition).analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-21479: Outer join no filter push down to preserved side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ testConstraintsAfterJoin(
+ x, y.where("a".attr === 1),
+ x, y.where(IsNotNull('a) && 'a === 1),
+ LeftOuter)
+ }
+
+ test("SPARK-23564: left anti join should filter out null join keys on right side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti)
+ }
+
+ test("SPARK-23564: left outer join should filter out null join keys on right side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter)
+ }
+
+ test("SPARK-23564: right outer join should filter out null join keys on left side") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+ testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index 2fb587d50a4cb..565b0a10154a8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -62,24 +62,15 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
}
}
- /** Set up tables and columns for testing */
private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
- attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("t5.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("t5.v-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4)
+ attr("t1.k-1-2") -> rangeColumnStat(2, 0),
+ attr("t1.v-1-10") -> rangeColumnStat(10, 0),
+ attr("t2.k-1-5") -> rangeColumnStat(5, 0),
+ attr("t3.v-1-100") -> rangeColumnStat(100, 0),
+ attr("t4.k-1-2") -> rangeColumnStat(2, 0),
+ attr("t4.v-1-10") -> rangeColumnStat(10, 0),
+ attr("t5.k-1-5") -> rangeColumnStat(5, 0),
+ attr("t5.v-1-5") -> rangeColumnStat(5, 0)
))
private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index 3964508e3a55e..f1ce7543ffdc1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
class PropagateEmptyRelationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -37,7 +37,8 @@ class PropagateEmptyRelationSuite extends PlanTest {
ReplaceIntersectWithSemiJoin,
PushDownPredicate,
PruneFilters,
- PropagateEmptyRelation) :: Nil
+ PropagateEmptyRelation,
+ CollapseProject) :: Nil
}
object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] {
@@ -48,7 +49,8 @@ class PropagateEmptyRelationSuite extends PlanTest {
ReplaceExceptWithAntiJoin,
ReplaceIntersectWithSemiJoin,
PushDownPredicate,
- PruneFilters) :: Nil
+ PruneFilters,
+ CollapseProject) :: Nil
}
val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1)))
@@ -79,9 +81,11 @@ class PropagateEmptyRelationSuite extends PlanTest {
(true, false, Inner, Some(LocalRelation('a.int, 'b.int))),
(true, false, Cross, Some(LocalRelation('a.int, 'b.int))),
- (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)),
+ (true, false, LeftOuter,
+ Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)),
(true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))),
- (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)),
+ (true, false, FullOuter,
+ Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)),
(true, false, LeftAnti, Some(testRelation1)),
(true, false, LeftSemi, Some(LocalRelation('a.int))),
@@ -89,8 +93,9 @@ class PropagateEmptyRelationSuite extends PlanTest {
(false, true, Cross, Some(LocalRelation('a.int, 'b.int))),
(false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))),
(false, true, RightOuter,
- Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)),
- (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)),
+ Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)),
+ (false, true, FullOuter,
+ Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)),
(false, true, LeftAnti, Some(LocalRelation('a.int))),
(false, true, LeftSemi, Some(LocalRelation('a.int))),
@@ -209,4 +214,11 @@ class PropagateEmptyRelationSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("propagate empty relation keeps the plan resolved") {
+ val query = testRelation1.join(
+ LocalRelation('a.int, 'b.int), UsingJoin(FullOuter, "a" :: Nil), None)
+ val optimized = Optimize.execute(query.analyze)
+ assert(optimized.resolved)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala
new file mode 100644
index 0000000000000..dae5e6f3ee3dd
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+class RemoveRedundantSortsSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Remove Redundant Sorts", Once,
+ RemoveRedundantSorts) ::
+ Batch("Collapse Project", Once,
+ CollapseProject) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ test("remove redundant order by") {
+ val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
+ val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
+ val optimized = Optimize.execute(unnecessaryReordered.analyze)
+ val correctAnswer = orderedPlan.limit(2).select('a).analyze
+ comparePlans(Optimize.execute(optimized), correctAnswer)
+ }
+
+ test("do not remove sort if the order is different") {
+ val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
+ val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc)
+ val optimized = Optimize.execute(reorderedDifferently.analyze)
+ val correctAnswer = reorderedDifferently.analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("filters don't affect order") {
+ val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
+ val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
+ val optimized = Optimize.execute(filteredAndReordered.analyze)
+ val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("limits don't affect order") {
+ val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
+ val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
+ val optimized = Optimize.execute(filteredAndReordered.analyze)
+ val correctAnswer = orderedPlan.limit(Literal(10)).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("different sorts are not simplified if limit is in between") {
+ val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10))
+ .orderBy('a.asc)
+ val optimized = Optimize.execute(orderedPlan.analyze)
+ val correctAnswer = orderedPlan.analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("range is already sorted") {
+ val inputPlan = Range(1L, 1000L, 1, 10)
+ val orderedPlan = inputPlan.orderBy('id.asc)
+ val optimized = Optimize.execute(orderedPlan.analyze)
+ val correctAnswer = inputPlan.analyze
+ comparePlans(optimized, correctAnswer)
+
+ val reversedPlan = inputPlan.orderBy('id.desc)
+ val reversedOptimized = Optimize.execute(reversedPlan.analyze)
+ val reversedCorrectAnswer = reversedPlan.analyze
+ comparePlans(reversedOptimized, reversedCorrectAnswer)
+
+ val negativeStepInputPlan = Range(10L, 1L, -1, 10)
+ val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc)
+ val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze)
+ val negativeStepCorrectAnswer = negativeStepInputPlan.analyze
+ comparePlans(negativeStepOptimized, negativeStepCorrectAnswer)
+ }
+
+ test("sort should not be removed when there is a node which doesn't guarantee any order") {
+ val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc)
+ val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc)
+ val optimized = Optimize.execute(groupedAndResorted.analyze)
+ val correctAnswer = groupedAndResorted.analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("remove two consecutive sorts") {
+ val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc)
+ val optimized = Optimize.execute(orderedTwice.analyze)
+ val correctAnswer = testRelation.orderBy('b.desc).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("remove sorts separated by Filter/Project operators") {
+ val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc)
+ val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze)
+ val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze
+ comparePlans(optimizedWithProject, correctAnswerWithProject)
+
+ val orderedTwiceWithFilter =
+ testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc)
+ val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze)
+ val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze
+ comparePlans(optimizedWithFilter, correctAnswerWithFilter)
+
+ val orderedTwiceWithBoth =
+ testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc)
+ val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze)
+ val correctAnswerWithBoth =
+ testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze
+ comparePlans(optimizedWithBoth, correctAnswerWithBoth)
+
+ val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc)
+ val optimizedThrice = Optimize.execute(orderedThrice.analyze)
+ val correctAnswerThrice = testRelation.select('b).where('b > Literal(0))
+ .select(('b + 1).as('c)).orderBy('c.asc).analyze
+ comparePlans(optimizedThrice, correctAnswerThrice)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala
index ada6e2a43ea0f..d4d23ad69b2c2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala
@@ -68,88 +68,56 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
// F1 (fact table)
- attr("f1_fk1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("f1_fk2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("f1_fk3") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("f1_c1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("f1_c2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100),
- nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("f1_fk1") -> rangeColumnStat(100, 0),
+ attr("f1_fk2") -> rangeColumnStat(100, 0),
+ attr("f1_fk3") -> rangeColumnStat(100, 0),
+ attr("f1_c1") -> rangeColumnStat(100, 0),
+ attr("f1_c2") -> rangeColumnStat(100, 0),
// D1 (dimension)
- attr("d1_pk") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d1_c2") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d1_c3") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50),
- nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d1_pk") -> rangeColumnStat(100, 0),
+ attr("d1_c2") -> rangeColumnStat(50, 0),
+ attr("d1_c3") -> rangeColumnStat(50, 0),
// D2 (dimension)
- attr("d2_pk") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d2_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d2_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d2_pk") -> rangeColumnStat(20, 0),
+ attr("d2_c2") -> rangeColumnStat(10, 0),
+ attr("d2_c3") -> rangeColumnStat(10, 0),
// D3 (dimension)
- attr("d3_pk") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d3_pk") -> rangeColumnStat(10, 0),
+ attr("d3_c2") -> rangeColumnStat(5, 0),
+ attr("d3_c3") -> rangeColumnStat(5, 0),
// T1 (regular table i.e. outside star)
- attr("t1_c1") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t1_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t1_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 1, avgLen = 4, maxLen = 4),
+ attr("t1_c1") -> rangeColumnStat(20, 1),
+ attr("t1_c2") -> rangeColumnStat(10, 1),
+ attr("t1_c3") -> rangeColumnStat(10, 1),
// T2 (regular table)
- attr("t2_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t2_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t2_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
+ attr("t2_c1") -> rangeColumnStat(5, 1),
+ attr("t2_c2") -> rangeColumnStat(5, 1),
+ attr("t2_c3") -> rangeColumnStat(5, 1),
// T3 (regular table)
- attr("t3_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
+ attr("t3_c1") -> rangeColumnStat(5, 1),
+ attr("t3_c2") -> rangeColumnStat(5, 1),
+ attr("t3_c3") -> rangeColumnStat(5, 1),
// T4 (regular table)
- attr("t4_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t4_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t4_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
+ attr("t4_c1") -> rangeColumnStat(5, 1),
+ attr("t4_c2") -> rangeColumnStat(5, 1),
+ attr("t4_c3") -> rangeColumnStat(5, 1),
// T5 (regular table)
- attr("t5_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t5_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t5_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
+ attr("t5_c1") -> rangeColumnStat(5, 1),
+ attr("t5_c2") -> rangeColumnStat(5, 1),
+ attr("t5_c3") -> rangeColumnStat(5, 1),
// T6 (regular table)
- attr("t6_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t6_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("t6_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 1, avgLen = 4, maxLen = 4)
+ attr("t6_c1") -> rangeColumnStat(5, 1),
+ attr("t6_c2") -> rangeColumnStat(5, 1),
+ attr("t6_c3") -> rangeColumnStat(5, 1)
))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
index 777c5637201ed..4e0883e91e84a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
@@ -70,59 +70,40 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
// Tables' cardinality: f1 > d3 > d1 > d2 > s3
private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
// F1
- attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
- nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("f1_fk1") -> rangeColumnStat(3, 0),
+ attr("f1_fk2") -> rangeColumnStat(3, 0),
+ attr("f1_fk3") -> rangeColumnStat(4, 0),
+ attr("f1_c4") -> rangeColumnStat(4, 0),
// D1
- attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d1_pk1") -> rangeColumnStat(4, 0),
+ attr("d1_c2") -> rangeColumnStat(3, 0),
+ attr("d1_c3") -> rangeColumnStat(4, 0),
+ attr("d1_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
// D2
- attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 1, avgLen = 4, maxLen = 4),
- attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4),
- nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d2_c2") -> ColumnStat(distinctCount = Some(3), min = Some("1"), max = Some("3"),
+ nullCount = Some(1), avgLen = Some(4), maxLen = Some(4)),
+ attr("d2_pk1") -> rangeColumnStat(3, 0),
+ attr("d2_c3") -> rangeColumnStat(3, 0),
+ attr("d2_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
// D3
- attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("d3_fk1") -> rangeColumnStat(3, 0),
+ attr("d3_c2") -> rangeColumnStat(3, 0),
+ attr("d3_pk1") -> rangeColumnStat(5, 0),
+ attr("d3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
// S3
- attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("s3_c2") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4),
- nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("s3_pk1") -> rangeColumnStat(2, 0),
+ attr("s3_c2") -> rangeColumnStat(1, 0),
+ attr("s3_c3") -> rangeColumnStat(1, 0),
+ attr("s3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
// F11
- attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4),
- nullCount = 0, avgLen = 4, maxLen = 4)
+ attr("f11_fk1") -> rangeColumnStat(3, 0),
+ attr("f11_fk2") -> rangeColumnStat(3, 0),
+ attr("f11_fk3") -> rangeColumnStat(4, 0),
+ attr("f11_c4") -> rangeColumnStat(4, 0)
))
private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala
new file mode 100644
index 0000000000000..09b11f5aba2a0
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+
+class UpdateNullabilityInAttributeReferencesSuite extends PlanTest {
+
+ object Optimizer extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Constant Folding", FixedPoint(10),
+ NullPropagation,
+ ConstantFolding,
+ BooleanSimplification,
+ SimplifyConditionals,
+ SimplifyBinaryComparison,
+ SimplifyExtractValueOps) ::
+ Batch("UpdateAttributeReferences", Once,
+ UpdateNullabilityInAttributeReferences) :: Nil
+ }
+
+ test("update nullability in AttributeReference") {
+ val rel = LocalRelation('a.long.notNull)
+ // In the 'original' plan below, the Aggregate node produced by groupBy() holds a
+ // nullable AttributeReference to `b`, because array indexing is a nullable expression.
+ // After optimization the underlying value becomes non-nullable, but the
+ // AttributeReference is not updated to reflect this, so the
+ // `UpdateNullabilityInAttributeReferences` rule must update its nullability.
+ val original = rel
+ .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b")
+ .groupBy($"b")("1")
+ val expected = rel.select('a as "b").groupBy($"b")("1").analyze
+ val optimized = Optimizer.execute(original.analyze)
+ comparePlans(optimized, expected)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index de544ac314789..5452e72b38647 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -44,14 +44,20 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
BooleanSimplification,
SimplifyConditionals,
SimplifyBinaryComparison,
- SimplifyCreateStructOps,
- SimplifyCreateArrayOps,
- SimplifyCreateMapOps) :: Nil
+ SimplifyExtractValueOps) :: Nil
}
- val idAtt = ('id).long.notNull
+ private val idAtt = ('id).long.notNull
+ private val nullableIdAtt = ('nullable_id).long
- lazy val relation = LocalRelation(idAtt )
+ private val relation = LocalRelation(idAtt, nullableIdAtt)
+ private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int)
+
+ private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
+ val optimized = Optimizer.execute(originalQuery.analyze)
+ assert(optimized.resolved, "optimized plans must still be resolvable")
+ comparePlans(optimized, correctAnswer.analyze)
+ }
test("explicit get from namedStruct") {
val query = relation
@@ -59,31 +65,28 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
GetStructField(
CreateNamedStruct(Seq("att", 'id )),
0,
- None) as "outerAtt").analyze
- val expected = relation.select('id as "outerAtt").analyze
+ None) as "outerAtt")
+ val expected = relation.select('id as "outerAtt")
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("explicit get from named_struct- expression maintains original deduced alias") {
val query = relation
.select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None))
- .analyze
val expected = relation
.select('id as "named_struct(att, id).att")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("collapsed getStructField ontop of namedStruct") {
val query = relation
.select(CreateNamedStruct(Seq("att", 'id)) as "struct1")
.select(GetStructField('struct1, 0, None) as "struct1Att")
- .analyze
- val expected = relation.select('id as "struct1Att").analyze
- comparePlans(Optimizer execute query, expected)
+ val expected = relation.select('id as "struct1Att")
+ checkRule(query, expected)
}
test("collapse multiple CreateNamedStruct/GetStructField pairs") {
@@ -95,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
.select(
GetStructField('struct1, 0, None) as "struct1Att1",
GetStructField('struct1, 1, None) as "struct1Att2")
- .analyze
val expected =
relation.
select(
'id as "struct1Att1",
('id * 'id) as "struct1Att2")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("collapsed2 - deduced names") {
@@ -116,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
.select(
GetStructField('struct1, 0, None),
GetStructField('struct1, 1, None))
- .analyze
val expected =
relation.
select(
'id as "struct1.att1",
('id * 'id) as "struct1.att2")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("simplified array ops") {
@@ -152,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
1,
false),
1) as "a4")
- .analyze
val expected = relation
.select(
@@ -162,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
"att2", (('id + 1L) * ('id + 1L)))) as "a2",
('id + 1L) as "a3",
('id + 1L) as "a4")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("SPARK-22570: CreateArray should not create a lot of global variables") {
@@ -189,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
GetStructField(GetMapValue('m, "r1"), 0, None) as "a2",
GetMapValue('m, "r32") as "a3",
GetStructField(GetMapValue('m, "r32"), 0, None) as "a4")
- .analyze
val expected =
relation.select(
@@ -202,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
)
) as "a3",
Literal.create(null, LongType) as "a4")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("simplify map ops, constant lookup, dynamic keys") {
@@ -217,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
13L) as "a")
- .analyze
val expected = relation
.select(
@@ -226,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
(EqualTo(13L, ('id + 1L)), ('id + 2L)),
(EqualTo(13L, ('id + 2L)), ('id + 3L)),
(Literal(true), 'id))) as "a")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") {
@@ -241,7 +234,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
('id + 3L)) as "a")
- .analyze
val expected = relation
.select(
CaseWhen(Seq(
@@ -249,8 +241,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
(EqualTo('id + 3L, ('id + 1L)), ('id + 2L)),
(EqualTo('id + 3L, ('id + 2L)), ('id + 3L)),
(Literal(true), ('id + 4L)))) as "a")
- .analyze
- comparePlans(Optimizer execute query, expected)
+ checkRule(query, expected)
}
test("simplify map ops, no positive match") {
@@ -264,7 +255,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
'id + 30L) as "a")
- .analyze
val expected = relation.select(
CaseWhen(Seq(
(EqualTo('id + 30L, 'id), ('id + 1L)),
@@ -272,8 +262,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
(EqualTo('id + 30L, ('id + 2L)), ('id + 3L)),
(EqualTo('id + 30L, ('id + 3L)), ('id + 4L)),
(EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a")
- .analyze
- comparePlans(Optimizer execute rel, expected)
+ checkRule(rel, expected)
}
test("simplify map ops, constant lookup, mixed keys, eliminated constants") {
@@ -288,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
13L) as "a")
- .analyze
val expected = relation
.select(
@@ -298,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 2L), ('id + 3L),
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))) as "a")
- .analyze
- comparePlans(Optimizer execute rel, expected)
+ checkRule(rel, expected)
}
test("simplify map ops, potential dynamic match with null value + an absolute constant match") {
@@ -315,20 +302,154 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
('id + 3L), ('id + 4L),
('id + 4L), ('id + 5L))),
2L ) as "a")
- .analyze
val expected = relation
.select(
CaseWhen(Seq(
(EqualTo(2L, 'id), ('id + 1L)),
- // these two are possible matches, we can't tell untill runtime
+ // these two are possible matches, we can't tell until runtime
(EqualTo(2L, ('id + 1L)), ('id + 2L)),
(EqualTo(2L, 'id + 2L), Literal.create(null, LongType)),
// this is a definite match (two constants),
// but it cannot override a potential match with ('id + 2L),
// which is exactly what [[Coalesce]] would do in this case.
(Literal.TrueLiteral, 'id))) as "a")
- .analyze
- comparePlans(Optimizer execute rel, expected)
+ checkRule(rel, expected)
+ }
+
+ test("SPARK-23500: Simplify array ops that are not at the top node") {
+ val query = LocalRelation('id.long)
+ .select(
+ CreateArray(Seq(
+ CreateNamedStruct(Seq(
+ "att1", 'id,
+ "att2", 'id * 'id)),
+ CreateNamedStruct(Seq(
+ "att1", 'id + 1,
+ "att2", ('id + 1) * ('id + 1))
+ ))
+ ) as "arr")
+ .select(
+ GetStructField(GetArrayItem('arr, 1), 0, None) as "a1",
+ GetArrayItem(
+ GetArrayStructFields('arr,
+ StructField("att1", LongType, nullable = false),
+ ordinal = 0,
+ numFields = 1,
+ containsNull = false),
+ ordinal = 1) as "a2")
+ .orderBy('id.asc)
+
+ val expected = LocalRelation('id.long)
+ .select(
+ ('id + 1L) as "a1",
+ ('id + 1L) as "a2")
+ .orderBy('id.asc)
+ checkRule(query, expected)
+ }
+
+ test("SPARK-23500: Simplify map ops that are not top nodes") {
+ val query =
+ LocalRelation('id.long)
+ .select(
+ CreateMap(Seq(
+ "r1", 'id,
+ "r2", 'id + 1L)) as "m")
+ .select(
+ GetMapValue('m, "r1") as "a1",
+ GetMapValue('m, "r32") as "a2")
+ .orderBy('id.asc)
+ .select('a1, 'a2)
+
+ val expected =
+ LocalRelation('id.long).select(
+ 'id as "a1",
+ Literal.create(null, LongType) as "a2")
+ .orderBy('id.asc)
+ checkRule(query, expected)
+ }
+
+ test("SPARK-23500: Simplify complex ops that aren't at the plan root") {
+ val structRel = relation
+ .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo")
+ .groupBy($"foo")("1")
+ val structExpected = relation
+ .select('nullable_id as "foo")
+ .groupBy($"foo")("1")
+ checkRule(structRel, structExpected)
+
+ val arrayRel = relation
+ .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1")
+ .groupBy($"a1")("1")
+ val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1")
+ checkRule(arrayRel, arrayExpected)
+
+ val mapRel = relation
+ .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1")
+ .groupBy($"m1")("1")
+ val mapExpected = relation
+ .select('nullable_id as "m1")
+ .groupBy($"m1")("1")
+ checkRule(mapRel, mapExpected)
+ }
+
+ test("SPARK-23500: Ensure that aggregation expressions are not simplified") {
+ // Make sure that aggregation exprs are correctly ignored. Maps can't be used in
+ // grouping exprs, so they aren't tested here.
+ val structAggRel = relation.groupBy(
+ CreateNamedStruct(Seq("att1", 'nullable_id)))(
+ GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None))
+ checkRule(structAggRel, structAggRel)
+
+ val arrayAggRel = relation.groupBy(
+ CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
+ checkRule(arrayAggRel, arrayAggRel)
+
+ // This could be done if we had a more complex rule that checks that
+ // the CreateMap does not come from the grouping key.
+ val originalQuery = relation
+ .groupBy('id)(
+ GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
+ )
+ checkRule(originalQuery, originalQuery)
+ }
+
+ test("SPARK-23500: namedStruct and getField in the same Project #1") {
+ val originalQuery =
+ testRelation
+ .select(
+ namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b)
+ .select('s1 getField "col2" as 's1Col2,
+ namedStruct("col1", 'a, "col2", 'b).as("s2"))
+ .select('s1Col2, 's2 getField "col2" as 's2Col2)
+ val correctAnswer =
+ testRelation
+ .select('c as 's1Col2, 'b as 's2Col2)
+ checkRule(originalQuery, correctAnswer)
+ }
+
+ test("SPARK-23500: namedStruct and getField in the same Project #2") {
+ val originalQuery =
+ testRelation
+ .select(
+ namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2,
+ namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1)
+ val correctAnswer =
+ testRelation
+ .select('c as 'sCol2, 'a as 'sCol1)
+ checkRule(originalQuery, correctAnswer)
+ }
+
+ test("SPARK-24313: support binary type as map keys in GetMapValue") {
+ val mb0 = Literal.create(
+ Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
+ MapType(BinaryType, StringType))
+ val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType))
+
+ checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null)
+
+ checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null)
+ checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
+ checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 812bfdd7bb885..fb51376c6163f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -318,6 +318,16 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(
"select * from t lateral view posexplode(x) posexpl as x, y",
expected)
+
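+ // Combining PIVOT with LATERAL VIEW in the same FROM clause should be rejected.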
+ intercept(
+ """select *
+ |from t
+ |lateral view explode(x) expl
+ |pivot (
+ | sum(x)
+ | FOR y IN ('a', 'b')
+ |)""".stripMargin,
+ "LATERAL cannot be used together with PIVOT in FROM clause")
}
test("joins") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
index cc80a41df998d..ff0de0fb7c1f0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
@@ -41,17 +41,17 @@ class TableIdentifierParserSuite extends SparkFunSuite {
"sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables",
"tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive",
"undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp",
- "view", "while", "year", "work", "transaction", "write", "isolation", "level",
- "snapshot", "autocommit", "all", "alter", "array", "as", "authorization", "between", "bigint",
+ "view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot",
+ "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint",
"binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp",
"cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external",
"false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in",
- "insert", "int", "into", "is", "lateral", "like", "local", "none", "null",
+ "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null",
"of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke",
"rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger",
"true", "truncate", "update", "user", "values", "with", "regexp", "rlike",
"bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float",
- "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing")
+ "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract")
val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right",
"natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 14041747fd20e..bf569cb869428 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType
@@ -101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite {
assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
}
+
+ test("transformExpressions works with a Stream") {
+ val id1 = NamedExpression.newExprId
+ val id2 = NamedExpression.newExprId
+ val plan = Project(Stream(
+ Alias(Literal(1), "a")(exprId = id1),
+ Alias(Literal(2), "b")(exprId = id2)),
+ OneRowRelation())
+ val result = plan.transformExpressions {
+ case Literal(v: Int, IntegerType) if v != 1 =>
+ Literal(v + 1, IntegerType)
+ }
+ val expected = Project(Stream(
+ Alias(Literal(1), "a")(exprId = id1),
+ Alias(Literal(3), "b")(exprId = id2)),
+ OneRowRelation())
+ assert(result.sameResult(expected))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
new file mode 100644
index 0000000000000..27914ef5565c0
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.plans
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+import org.apache.spark.sql.types.IntegerType
+
+class QueryPlanSuite extends SparkFunSuite {
+
+ test("origin remains the same after mapExpressions (SPARK-23823)") {
+ CurrentOrigin.setPosition(0, 0)
+ val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId)
+ val query = plans.DslLogicalPlan(plans.table("table")).select(column)
+ CurrentOrigin.reset()
+
+ val mappedQuery = query mapExpressions {
+ case _: Expression => Literal(1)
+ }
+
+ val mappedOrigin = mappedQuery.expressions.apply(0).origin
+ assert(mappedOrigin == Origin.apply(Some(0), Some(0)))
+ }
+
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
index 23f95a6cc2ac2..8213d568fe85e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
@@ -29,16 +29,16 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest {
/** Columns for testing */
private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
- attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key12") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key22") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key31") -> ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0,
- avgLen = 4, maxLen = 4)
+ attr("key11") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attr("key12") -> ColumnStat(distinctCount = Some(4), min = Some(10), max = Some(40),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attr("key21") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attr("key22") -> ColumnStat(distinctCount = Some(2), min = Some(10), max = Some(20),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attr("key31") -> ColumnStat(distinctCount = Some(0), min = None, max = None,
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
))
private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
@@ -63,8 +63,8 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest {
tableRowCount = 6,
groupByColumns = Seq("key21", "key22"),
// Row count = product of ndv
- expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount * nameToColInfo("key22")._2
- .distinctCount)
+ expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount.get *
+ nameToColInfo("key22")._2.distinctCount.get)
}
test("empty group-by column") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index 7d532ff343178..953094cb0dd52 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.types.IntegerType
class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
val attribute = attr("key")
- val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)
+ val colStat = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
val plan = StatsTestPlan(
outputList = Seq(attribute),
@@ -116,13 +116,17 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
sizeInBytes = 40,
rowCount = Some(10),
attributeStats = AttributeMap(Seq(
- AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))))
+ AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10),
+ min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))))
val expectedCboStats =
Statistics(
sizeInBytes = 4,
rowCount = Some(1),
attributeStats = AttributeMap(Seq(
- AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))))
+ AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10),
+ min = Some(5), max = Some(5),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))))
val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats)
checkStats(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index 2b1fe987a7960..47bfa62569583 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -37,59 +37,61 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
// Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
val attrInt = AttributeReference("cint", IntegerType)()
- val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)
+ val colStatInt = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
// column cbool has only 2 distinct values
val attrBool = AttributeReference("cbool", BooleanType)()
- val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
- nullCount = 0, avgLen = 1, maxLen = 1)
+ val colStatBool = ColumnStat(distinctCount = Some(2), min = Some(false), max = Some(true),
+ nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))
// column cdate has 10 values from 2017-01-01 through 2017-01-10.
val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01"))
val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10"))
val attrDate = AttributeReference("cdate", DateType)()
- val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
- nullCount = 0, avgLen = 4, maxLen = 4)
+ val colStatDate = ColumnStat(distinctCount = Some(10),
+ min = Some(dMin), max = Some(dMax),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
// column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
val decMin = Decimal("0.200000000000000000")
val decMax = Decimal("0.800000000000000000")
val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
- val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
- nullCount = 0, avgLen = 8, maxLen = 8)
+ val colStatDecimal = ColumnStat(distinctCount = Some(4),
+ min = Some(decMin), max = Some(decMax),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))
// column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
val attrDouble = AttributeReference("cdouble", DoubleType)()
- val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0),
- nullCount = 0, avgLen = 8, maxLen = 8)
+ val colStatDouble = ColumnStat(distinctCount = Some(10), min = Some(1.0), max = Some(10.0),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))
// column cstring has 10 String values:
// "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9"
val attrString = AttributeReference("cstring", StringType)()
- val colStatString = ColumnStat(distinctCount = 10, min = None, max = None,
- nullCount = 0, avgLen = 2, maxLen = 2)
+ val colStatString = ColumnStat(distinctCount = Some(10), min = None, max = None,
+ nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))
// column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
// Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4
// This column is created to test "cint < cint2
val attrInt2 = AttributeReference("cint2", IntegerType)()
- val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16),
- nullCount = 0, avgLen = 4, maxLen = 4)
+ val colStatInt2 = ColumnStat(distinctCount = Some(10), min = Some(7), max = Some(16),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
// column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39
// Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4
// This column is created to test "cint = cint3 without overlap at all.
val attrInt3 = AttributeReference("cint3", IntegerType)()
- val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39),
- nullCount = 0, avgLen = 4, maxLen = 4)
+ val colStatInt3 = ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
// column cint4 has values in the range from 1 to 10
// distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
// This column is created to test complete overlap
val attrInt4 = AttributeReference("cint4", IntegerType)()
- val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)
+ val colStatInt4 = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
// column cintHgm has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 with histogram.
// Note that cintHgm has an even distribution with histogram information built.
@@ -98,8 +100,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val hgmInt = Histogram(2.0, Array(HistogramBin(1.0, 2.0, 2),
HistogramBin(2.0, 4.0, 2), HistogramBin(4.0, 6.0, 2),
HistogramBin(6.0, 8.0, 2), HistogramBin(8.0, 10.0, 2)))
- val colStatIntHgm = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))
+ val colStatIntHgm = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))
// column cintSkewHgm has values: 1, 4, 4, 5, 5, 5, 5, 6, 6, 10 with histogram.
// Note that cintSkewHgm has a skewed distribution with histogram information built.
@@ -108,8 +110,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val hgmIntSkew = Histogram(2.0, Array(HistogramBin(1.0, 4.0, 2),
HistogramBin(4.0, 5.0, 2), HistogramBin(5.0, 5.0, 1),
HistogramBin(5.0, 6.0, 2), HistogramBin(6.0, 10.0, 2)))
- val colStatIntSkewHgm = ColumnStat(distinctCount = 5, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))
+ val colStatIntSkewHgm = ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))
val attributeMap = AttributeMap(Seq(
attrInt -> colStatInt,
@@ -172,7 +174,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt.copy(distinctCount = 3)),
+ Seq(attrInt -> colStatInt.copy(distinctCount = Some(3))),
expectedRowCount = 3)
}
@@ -180,7 +182,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt.copy(distinctCount = 8)),
+ Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))),
expectedRowCount = 8)
}
@@ -196,23 +198,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType))))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt.copy(distinctCount = 8)),
+ Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))),
expectedRowCount = 8)
}
test("cint = 2") {
validateEstimatedStats(
Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 1)
}
test("cint <=> 2") {
validateEstimatedStats(
Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 1)
}
@@ -227,8 +229,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint < 3") {
validateEstimatedStats(
Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 3)
}
@@ -243,16 +245,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint <= 3") {
validateEstimatedStats(
Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 3)
}
test("cint > 6") {
validateEstimatedStats(
Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 5)
}
@@ -267,8 +269,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint >= 6") {
validateEstimatedStats(
Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 5)
}
@@ -282,8 +284,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cint IS NOT NULL") {
validateEstimatedStats(
Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 10)
}
@@ -301,8 +303,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 4)
}
@@ -310,7 +312,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt.copy(distinctCount = 2)),
+ Seq(attrInt -> colStatInt.copy(distinctCount = Some(2))),
expectedRowCount = 2)
}
@@ -318,7 +320,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt.copy(distinctCount = 6)),
+ Seq(attrInt -> colStatInt.copy(distinctCount = Some(6))),
expectedRowCount = 6)
}
@@ -326,7 +328,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6))))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt.copy(distinctCount = 5)),
+ Seq(attrInt -> colStatInt.copy(distinctCount = Some(5))),
expectedRowCount = 5)
}
@@ -342,47 +344,70 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8"))))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)),
- Seq(attrInt -> colStatInt.copy(distinctCount = 9),
- attrString -> colStatString.copy(distinctCount = 9)),
+ Seq(attrInt -> colStatInt.copy(distinctCount = Some(9)),
+ attrString -> colStatString.copy(distinctCount = Some(9))),
expectedRowCount = 9)
}
test("cint IN (3, 4, 5)") {
validateEstimatedStats(
Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 3)
}
+ test("evaluateInSet with all zeros") {
+ validateEstimatedStats(
+ Filter(InSet(attrString, Set(3, 4, 5)),
+ StatsTestPlan(Seq(attrString), 0,
+ AttributeMap(Seq(attrString ->
+ ColumnStat(distinctCount = Some(0), min = None, max = None,
+ nullCount = Some(0), avgLen = Some(0), maxLen = Some(0)))))),
+ Seq(attrString -> ColumnStat(distinctCount = Some(0))),
+ expectedRowCount = 0)
+ }
+
+ test("evaluateInSet with string") {
+ validateEstimatedStats(
+ Filter(InSet(attrString, Set("A0")),
+ StatsTestPlan(Seq(attrString), 10,
+ AttributeMap(Seq(attrString ->
+ ColumnStat(distinctCount = Some(10), min = None, max = None,
+ nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)))))),
+ Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None,
+ nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))),
+ expectedRowCount = 1)
+ }
+
test("cint NOT IN (3, 4, 5)") {
validateEstimatedStats(
Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)),
- Seq(attrInt -> colStatInt.copy(distinctCount = 7)),
+ Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))),
expectedRowCount = 7)
}
test("cbool IN (true)") {
validateEstimatedStats(
Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)),
- Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
- nullCount = 0, avgLen = 1, maxLen = 1)),
+ Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true),
+ nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))),
expectedRowCount = 5)
}
test("cbool = true") {
validateEstimatedStats(
Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)),
- Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
- nullCount = 0, avgLen = 1, maxLen = 1)),
+ Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true),
+ nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))),
expectedRowCount = 5)
}
test("cbool > false") {
validateEstimatedStats(
Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)),
- Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(true),
- nullCount = 0, avgLen = 1, maxLen = 1)),
+ Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(false), max = Some(true),
+ nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))),
expectedRowCount = 5)
}
@@ -391,18 +416,21 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
validateEstimatedStats(
Filter(EqualTo(attrDate, Literal(d20170102, DateType)),
childStatsTestPlan(Seq(attrDate), 10L)),
- Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrDate -> ColumnStat(distinctCount = Some(1),
+ min = Some(d20170102), max = Some(d20170102),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 1)
}
test("cdate < cast('2017-01-03' AS DATE)") {
+ val d20170101 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01"))
val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
validateEstimatedStats(
Filter(LessThan(attrDate, Literal(d20170103, DateType)),
childStatsTestPlan(Seq(attrDate), 10L)),
- Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(dMin), max = Some(d20170103),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrDate -> ColumnStat(distinctCount = Some(3),
+ min = Some(d20170101), max = Some(d20170103),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 3)
}
@@ -414,8 +442,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
validateEstimatedStats(
Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType),
Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)),
- Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrDate -> ColumnStat(distinctCount = Some(3),
+ min = Some(d20170103), max = Some(d20170105),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 3)
}
@@ -424,42 +453,45 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
validateEstimatedStats(
Filter(EqualTo(attrDecimal, Literal(dec_0_40)),
childStatsTestPlan(Seq(attrDecimal), 4L)),
- Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
- nullCount = 0, avgLen = 8, maxLen = 8)),
+ Seq(attrDecimal -> ColumnStat(distinctCount = Some(1),
+ min = Some(dec_0_40), max = Some(dec_0_40),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))),
expectedRowCount = 1)
}
test("cdecimal < 0.60 ") {
+ val dec_0_20 = Decimal("0.200000000000000000")
val dec_0_60 = Decimal("0.600000000000000000")
validateEstimatedStats(
Filter(LessThan(attrDecimal, Literal(dec_0_60)),
childStatsTestPlan(Seq(attrDecimal), 4L)),
- Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
- nullCount = 0, avgLen = 8, maxLen = 8)),
+ Seq(attrDecimal -> ColumnStat(distinctCount = Some(3),
+ min = Some(dec_0_20), max = Some(dec_0_60),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))),
expectedRowCount = 3)
}
test("cdouble < 3.0") {
validateEstimatedStats(
Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)),
- Seq(attrDouble -> ColumnStat(distinctCount = 3, min = Some(1.0), max = Some(3.0),
- nullCount = 0, avgLen = 8, maxLen = 8)),
+ Seq(attrDouble -> ColumnStat(distinctCount = Some(3), min = Some(1.0), max = Some(3.0),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))),
expectedRowCount = 3)
}
test("cstring = 'A2'") {
validateEstimatedStats(
Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)),
- Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None,
- nullCount = 0, avgLen = 2, maxLen = 2)),
+ Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None,
+ nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))),
expectedRowCount = 1)
}
test("cstring < 'A2' - unsupported condition") {
validateEstimatedStats(
Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)),
- Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None,
- nullCount = 0, avgLen = 2, maxLen = 2)),
+ Seq(attrString -> ColumnStat(distinctCount = Some(10), min = None, max = None,
+ nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))),
expectedRowCount = 10)
}
@@ -468,8 +500,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// valid values in IN clause is greater than the number of distinct values for a given column.
// For example, column has only 2 distinct values 1 and 6.
// The predicate is: column IN (1, 2, 3, 4, 5).
- val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6),
- nullCount = 0, avgLen = 4, maxLen = 4)
+ val cornerChildColStatInt = ColumnStat(distinctCount = Some(2),
+ min = Some(1), max = Some(6),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
val cornerChildStatsTestplan = StatsTestPlan(
outputList = Seq(attrInt),
rowCount = 2L,
@@ -477,16 +510,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
)
validateEstimatedStats(
Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
- Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 2)
}
// This is a limitation test. We should remove it after the limitation is removed.
test("don't estimate IsNull or IsNotNull if the child is a non-leaf node") {
val attrIntLargerRange = AttributeReference("c1", IntegerType)()
- val colStatIntLargerRange = ColumnStat(distinctCount = 20, min = Some(1), max = Some(20),
- nullCount = 10, avgLen = 4, maxLen = 4)
+ val colStatIntLargerRange = ColumnStat(distinctCount = Some(20),
+ min = Some(1), max = Some(20),
+ nullCount = Some(10), avgLen = Some(4), maxLen = Some(4))
val smallerTable = childStatsTestPlan(Seq(attrInt), 10L)
val largerTable = StatsTestPlan(
outputList = Seq(attrIntLargerRange),
@@ -508,10 +542,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// partial overlap case
validateEstimatedStats(
Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 4)
}
@@ -519,10 +553,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// partial overlap case
validateEstimatedStats(
Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 4)
}
@@ -530,10 +564,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// partial overlap case
validateEstimatedStats(
Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(16),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(16),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 4)
}
@@ -541,10 +575,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// complete overlap case
validateEstimatedStats(
Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attrInt4 -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 10)
}
@@ -552,10 +586,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// partial overlap case
validateEstimatedStats(
Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt4 -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attrInt4 -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 4)
}
@@ -571,10 +605,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// all table records qualify.
validateEstimatedStats(
Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
- Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39),
- nullCount = 0, avgLen = 4, maxLen = 4)),
+ Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attrInt3 -> ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 10)
}
@@ -592,11 +626,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrInt, attrInt4, attrString), 10L)),
Seq(
- attrInt -> ColumnStat(distinctCount = 5, min = Some(3), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attrInt4 -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(6),
- nullCount = 0, avgLen = 4, maxLen = 4),
- attrString -> colStatString.copy(distinctCount = 5)),
+ attrInt -> ColumnStat(distinctCount = Some(5), min = Some(3), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attrInt4 -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(6),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attrString -> colStatString.copy(distinctCount = Some(5))),
expectedRowCount = 5)
}
@@ -606,15 +640,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(And(LessThan(attrIntHgm, Literal(3)), Literal(null, IntegerType)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)),
- Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 7)),
+ Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(7))),
expectedRowCount = 7)
}
test("cintHgm = 5") {
validateEstimatedStats(
Filter(EqualTo(attrIntHgm, Literal(5)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
- Seq(attrIntHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))),
expectedRowCount = 1)
}
@@ -629,8 +663,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cintHgm < 3") {
validateEstimatedStats(
Filter(LessThan(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
- Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))),
expectedRowCount = 3)
}
@@ -645,16 +679,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cintHgm <= 3") {
validateEstimatedStats(
Filter(LessThanOrEqual(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
- Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))),
expectedRowCount = 3)
}
test("cintHgm > 6") {
validateEstimatedStats(
Filter(GreaterThan(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
- Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(6), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))),
expectedRowCount = 4)
}
@@ -669,8 +703,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cintHgm >= 6") {
validateEstimatedStats(
Filter(GreaterThanOrEqual(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)),
- Seq(attrIntHgm -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))),
expectedRowCount = 5)
}
@@ -679,8 +713,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Literal(3)), LessThanOrEqual(attrIntHgm, Literal(6)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)),
- Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))),
+ Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))),
expectedRowCount = 4)
}
@@ -688,7 +722,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Or(EqualTo(attrIntHgm, Literal(3)), EqualTo(attrIntHgm, Literal(6)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)),
- Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 3)),
+ Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(3))),
expectedRowCount = 3)
}
@@ -698,15 +732,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Not(And(LessThan(attrIntSkewHgm, Literal(3)), Literal(null, IntegerType)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
- Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 5)),
+ Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(5))),
expectedRowCount = 9)
}
test("cintSkewHgm = 5") {
validateEstimatedStats(
Filter(EqualTo(attrIntSkewHgm, Literal(5)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
- Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))),
expectedRowCount = 4)
}
@@ -721,8 +755,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
test("cintSkewHgm < 3") {
validateEstimatedStats(
Filter(LessThan(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
- Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))),
expectedRowCount = 2)
}
@@ -738,16 +772,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
validateEstimatedStats(
Filter(LessThanOrEqual(attrIntSkewHgm, Literal(3)),
childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
- Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))),
expectedRowCount = 2)
}
test("cintSkewHgm > 6") {
validateEstimatedStats(
Filter(GreaterThan(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
- Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(6), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(6), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))),
expectedRowCount = 2)
}
@@ -764,8 +798,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
validateEstimatedStats(
Filter(GreaterThanOrEqual(attrIntSkewHgm, Literal(6)),
childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
- Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 2, min = Some(6), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(2), min = Some(6), max = Some(10),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))),
expectedRowCount = 3)
}
@@ -774,8 +808,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
Literal(3)), LessThanOrEqual(attrIntSkewHgm, Literal(6)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
- Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6),
- nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))),
+ Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))),
expectedRowCount = 8)
}
@@ -783,7 +817,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val condition = Or(EqualTo(attrIntSkewHgm, Literal(3)), EqualTo(attrIntSkewHgm, Literal(6)))
validateEstimatedStats(
Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)),
- Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 2)),
+ Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(2))),
expectedRowCount = 3)
}
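
Editor's note: the pattern in the hunks above is mechanical — every `ColumnStat` field that used to be a bare value is now wrapped in `Option`, so partially known statistics become representable. A minimal sketch of what the test expectations are adapting to, using a simplified stand-in for the real `ColumnStat` case class (field names mirror the diff; everything else is illustrative):

    object ColumnStatOptionSketch {
      // Simplified stand-in for catalyst's ColumnStat; types are illustrative only.
      case class ColumnStatSketch(
          distinctCount: Option[BigInt] = None,
          min: Option[Any] = None,
          max: Option[Any] = None,
          nullCount: Option[BigInt] = None,
          avgLen: Option[Long] = None,
          maxLen: Option[Long] = None)

      def main(args: Array[String]): Unit = {
        // Old style: ColumnStat(distinctCount = 1, nullCount = 0, avgLen = 4, maxLen = 4)
        // New style: every field is optional, so absent statistics no longer need fake defaults.
        val stat = ColumnStatSketch(distinctCount = Some(1), min = Some(5), max = Some(5),
          nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
        println(stat.distinctCount.getOrElse(BigInt(0)))
      }
    }
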
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index 26139d85d25fb..12c0a7be21292 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -33,16 +33,16 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
/** Set up tables and their columns for testing */
private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
- attr("key-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key-5-9") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key-2-4") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0,
- avgLen = 4, maxLen = 4),
- attr("key-2-3") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0,
- avgLen = 4, maxLen = 4)
+ attr("key-1-5") -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(5),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attr("key-5-9") -> ColumnStat(distinctCount = Some(5), min = Some(5), max = Some(9),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attr("key-1-2") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attr("key-2-4") -> ColumnStat(distinctCount = Some(3), min = Some(2), max = Some(4),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ attr("key-2-3") -> ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
))
private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
@@ -70,8 +70,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
private def estimateByHistogram(
leftHistogram: Histogram,
rightHistogram: Histogram,
- expectedMin: Double,
- expectedMax: Double,
+ expectedMin: Any,
+ expectedMax: Any,
expectedNdv: Long,
expectedRows: Long): Unit = {
val col1 = attr("key1")
@@ -86,9 +86,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
rowCount = Some(expectedRows),
attributeStats = AttributeMap(Seq(
col1 -> c1.stats.attributeStats(col1).copy(
- distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)),
+ distinctCount = Some(expectedNdv),
+ min = Some(expectedMin), max = Some(expectedMax)),
col2 -> c2.stats.attributeStats(col2).copy(
- distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax))))
+ distinctCount = Some(expectedNdv),
+ min = Some(expectedMin), max = Some(expectedMax))))
)
// Join order should not affect estimation result.
@@ -100,9 +102,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
private def generateJoinChild(
col: Attribute,
histogram: Histogram,
- expectedMin: Double,
- expectedMax: Double): LogicalPlan = {
- val colStat = inferColumnStat(histogram)
+ expectedMin: Any,
+ expectedMax: Any): LogicalPlan = {
+ val colStat = inferColumnStat(histogram, expectedMin, expectedMax)
StatsTestPlan(
outputList = Seq(col),
rowCount = (histogram.height * histogram.bins.length).toLong,
@@ -110,7 +112,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
}
/** Column statistics should be consistent with histograms in tests. */
- private def inferColumnStat(histogram: Histogram): ColumnStat = {
+ private def inferColumnStat(
+ histogram: Histogram,
+ expectedMin: Any,
+ expectedMax: Any): ColumnStat = {
+
var ndv = 0L
for (i <- histogram.bins.indices) {
val bin = histogram.bins(i)
@@ -118,8 +124,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
ndv += bin.ndv
}
}
- ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo),
- max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4,
+ ColumnStat(distinctCount = Some(ndv),
+ min = Some(expectedMin), max = Some(expectedMax),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4),
histogram = Some(histogram))
}
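
Editor's note: the reworked `inferColumnStat` above takes min/max from the caller but still derives the distinct count by summing bin NDVs; the elided condition in the loop (not shown in this hunk) is assumed to skip bins that merely repeat the previous bin's upper bound, which is how a single skewed value spanning several bins avoids being double-counted. A rough sketch of that accumulation under that assumption, outside Spark:

    object HistogramNdvSketch {
      // Minimal stand-in for the catalyst HistogramBin used above (illustrative only).
      case class Bin(lo: Double, hi: Double, ndv: Long)

      // Sum bin NDVs, skipping bins that only repeat the previous bin's upper bound.
      def inferNdv(bins: Seq[Bin]): Long =
        bins.zipWithIndex.collect {
          case (bin, i) if i == 0 || bin.hi != bins(i - 1).hi => bin.ndv
        }.sum

      def main(args: Array[String]): Unit = {
        val skewed = Seq(Bin(1, 5, 3), Bin(5, 5, 1), Bin(5, 5, 1), Bin(5, 10, 4))
        println(inferNdv(skewed)) // 3 + 4 = 7: the bins repeating hi = 5 are not double-counted
      }
    }
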
@@ -343,10 +350,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
rowCount = Some(5 + 3),
attributeStats = AttributeMap(
// Update null count in column stats.
- Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = 3),
- nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3),
- nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5),
- nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5))))
+ Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = Some(3)),
+ nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = Some(3)),
+ nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = Some(5)),
+ nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = Some(5)))))
assert(join.stats == expectedStats)
}
@@ -356,11 +363,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
val join = Join(table1, table2, Inner,
Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2"))))
// Update column stats for equi-join keys (key-1-5 and key-1-2).
- val joinedColStat = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
- avgLen = 4, maxLen = 4)
+ val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
// Update column stat for other column if #outputRow / #sideRow < 1 (key-5-9), or keep it
// unchanged (key-2-4).
- val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = 5 * 3 / 5)
+ val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = Some(5 * 3 / 5))
val expectedStats = Statistics(
sizeInBytes = 3 * (8 + 4 * 4),
@@ -379,10 +386,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))))
// Update column stats for join keys.
- val joinedColStat1 = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
- avgLen = 4, maxLen = 4)
- val joinedColStat2 = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0,
- avgLen = 4, maxLen = 4)
+ val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
+ val joinedColStat2 = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
val expectedStats = Statistics(
sizeInBytes = 2 * (8 + 4 * 4),
@@ -398,8 +405,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
val join = Join(table3, table2, LeftOuter,
Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4"))))
- val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0,
- avgLen = 4, maxLen = 4)
+ val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
val expectedStats = Statistics(
sizeInBytes = 2 * (8 + 4 * 4),
@@ -416,8 +423,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
// table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3)
val join = Join(table2, table3, RightOuter,
Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))
- val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0,
- avgLen = 4, maxLen = 4)
+ val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
val expectedStats = Statistics(
sizeInBytes = 2 * (8 + 4 * 4),
@@ -466,30 +473,40 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
mutable.LinkedHashMap[Attribute, ColumnStat](
- AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1,
- min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1),
- AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1,
- min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
- AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1,
- min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
- AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1,
- min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4),
- AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1,
- min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8),
- AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1,
- min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8),
- AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1,
- min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4),
- AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1,
- min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16),
- AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1,
- min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3),
- AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1,
- min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3),
- AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1,
- min = Some(date), max = Some(date), nullCount = 0, avgLen = 4, maxLen = 4),
- AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1,
- min = Some(timestamp), max = Some(timestamp), nullCount = 0, avgLen = 8, maxLen = 8)
+ AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(1),
+ min = Some(false), max = Some(false),
+ nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)),
+ AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(1),
+ min = Some(1.toByte), max = Some(1.toByte),
+ nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)),
+ AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(1),
+ min = Some(1.toShort), max = Some(1.toShort),
+ nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)),
+ AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(1),
+ min = Some(1), max = Some(1),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(1),
+ min = Some(1L), max = Some(1L),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)),
+ AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(1),
+ min = Some(1.0), max = Some(1.0),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)),
+ AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(1),
+ min = Some(1.0f), max = Some(1.0f),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(
+ distinctCount = Some(1), min = Some(dec), max = Some(dec),
+ nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)),
+ AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(1),
+ min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)),
+ AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(1),
+ min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)),
+ AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(1),
+ min = Some(date), max = Some(date),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(1),
+ min = Some(timestamp), max = Some(timestamp),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))
)
}
@@ -520,7 +537,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
test("join with null column") {
val (nullColumn, nullColStat) = (attr("cnull"),
- ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 1, avgLen = 4, maxLen = 4))
+ ColumnStat(distinctCount = Some(0), min = None, max = None,
+ nullCount = Some(1), avgLen = Some(4), maxLen = Some(4)))
val nullTable = StatsTestPlan(
outputList = Seq(nullColumn),
rowCount = 1,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
index cda54fa9d64f4..dcb37017329fc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -28,10 +28,10 @@ import org.apache.spark.sql.types._
class ProjectEstimationSuite extends StatsEstimationTestBase {
test("project with alias") {
- val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 2, min = Some(1),
- max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4))
- val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = 1, min = Some(10),
- max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4))
+ val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(2), min = Some(1),
+ max = Some(2), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))
+ val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = Some(1), min = Some(10),
+ max = Some(10), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))
val child = StatsTestPlan(
outputList = Seq(ar1, ar2),
@@ -49,8 +49,8 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
}
test("project on empty table") {
- val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 0, min = None, max = None,
- nullCount = 0, avgLen = 4, maxLen = 4))
+ val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(0), min = None, max = None,
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))
val child = StatsTestPlan(
outputList = Seq(ar1),
rowCount = 0,
@@ -71,30 +71,40 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02"))
val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
- AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2,
- min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1),
- AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2,
- min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
- AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2,
- min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
- AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2,
- min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4),
- AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2,
- min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8),
- AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2,
- min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8),
- AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2,
- min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4),
- AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2,
- min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16),
- AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2,
- min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3),
- AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2,
- min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3),
- AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2,
- min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4),
- AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2,
- min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8)
+ AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(2),
+ min = Some(false), max = Some(true),
+ nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)),
+ AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(2),
+ min = Some(1), max = Some(2),
+ nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)),
+ AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(2),
+ min = Some(1), max = Some(3),
+ nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)),
+ AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(2),
+ min = Some(1), max = Some(4),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(2),
+ min = Some(1), max = Some(5),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)),
+ AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(2),
+ min = Some(1.0), max = Some(6.0),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)),
+ AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(2),
+ min = Some(1.0), max = Some(7.0),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(
+ distinctCount = Some(2), min = Some(dec1), max = Some(dec2),
+ nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)),
+ AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(2),
+ min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)),
+ AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(2),
+ min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)),
+ AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(2),
+ min = Some(d1), max = Some(d2),
+ nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
+ AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(2),
+ min = Some(t1), max = Some(t2),
+ nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))
))
val columnSizes: Map[Attribute, Long] = columnInfo.map(kv => (kv._1, getColSize(kv._1, kv._2)))
val child = StatsTestPlan(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index 31dea2e3e7f1d..9dceca59f5b87 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -42,8 +42,8 @@ trait StatsEstimationTestBase extends SparkFunSuite {
def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
// For UTF8String: base + offset + numBytes
- case StringType => colStat.avgLen + 8 + 4
- case _ => colStat.avgLen
+ case StringType => colStat.avgLen.getOrElse(attribute.dataType.defaultSize.toLong) + 8 + 4
+ case _ => colStat.avgLen.getOrElse(attribute.dataType.defaultSize)
}
def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()
@@ -54,6 +54,12 @@ trait StatsEstimationTestBase extends SparkFunSuite {
val nameToAttr: Map[String, Attribute] = plan.output.map(a => (a.name, a)).toMap
AttributeMap(colStats.map(kv => nameToAttr(kv._1) -> kv._2))
}
+
+ /** Get a test ColumnStat with given distinctCount and nullCount */
+ def rangeColumnStat(distinctCount: Int, nullCount: Int): ColumnStat =
+ ColumnStat(distinctCount = Some(distinctCount),
+ min = Some(1), max = Some(distinctCount),
+ nullCount = Some(nullCount), avgLen = Some(4), maxLen = Some(4))
}
/**
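
Editor's note: the new `rangeColumnStat` helper and the `getOrElse` fallback in `getColSize` both defend against the now-optional fields. A small sketch of the fallback behaviour, using a hypothetical simplified signature (not the actual method):

    object ColSizeFallbackSketch {
      // Hypothetical, simplified getColSize: avgLen is optional, so fall back
      // to the data type's default size when the statistic is missing.
      def colSize(isString: Boolean, avgLen: Option[Long], defaultSize: Long): Long =
        if (isString) avgLen.getOrElse(defaultSize) + 8 + 4 // UTF8String: base + offset + numBytes
        else avgLen.getOrElse(defaultSize)

      def main(args: Array[String]): Unit = {
        println(colSize(isString = true, Some(3L), defaultSize = 20L)) // 3 + 8 + 4 = 15
        println(colSize(isString = true, None, defaultSize = 20L))     // 20 + 8 + 4 = 32
        println(colSize(isString = false, None, defaultSize = 4L))     // 4
      }
    }
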
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 84d0ba7bef642..b7092f4c42d4c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource}
+import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.dsl.expressions.DslString
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union}
import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition}
-import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback {
@@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite {
val right = JsonMethods.parse(rightJson)
assert(left == right)
}
+
+ test("transform works on stream of children") {
+ val before = Coalesce(Stream(Literal(1), Literal(2)))
+ // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
+ // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
+ // stream's first element is typically materialized, so in order to not trip the TreeNode change
+ // detection logic, we should not change the first element in the sequence.
+ val result = before.transform {
+ case Literal(v: Int, IntegerType) if v != 1 =>
+ Literal(v + 1, IntegerType)
+ }
+ val expected = Coalesce(Stream(Literal(1), Literal(3)))
+ assert(result === expected)
+ }
+
+ test("withNewChildren on stream of children") {
+ val before = Coalesce(Stream(Literal(1), Literal(2)))
+ val result = before.withNewChildren(Stream(Literal(1), Literal(3)))
+ val expected = Coalesce(Stream(Literal(1), Literal(3)))
+ assert(result === expected)
+ }
}
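
Editor's note: the two new tests guard against a subtlety of `Stream` children — `Stream.map` evaluates only the head eagerly, so a change-detection flag set inside the mapping closure never fires for elements that are not forced. A minimal sketch of that pitfall, outside Spark:

    object LazyStreamMapPitfall {
      def main(args: Array[String]): Unit = {
        var changed = false
        val children = Stream(1, 2)
        // The closure runs eagerly only for the head; the tail stays lazy.
        val newChildren = children.map { c =>
          val n = if (c != 1) c + 1 else c
          if (n != c) changed = true // side effect inside a lazy map
          n
        }
        println(newChildren.head) // 1
        println(changed)          // false: element 2 has not been forced yet
      }
    }
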
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
new file mode 100644
index 0000000000000..6400898343ae7
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection}
+import org.apache.spark.sql.types._
+
+class ArrayDataIndexedSeqSuite extends SparkFunSuite {
+ private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = {
+ assert(arrayData.numElements == array.length)
+ array.zipWithIndex.map { case (e, i) =>
+ if (e != null) {
+ elementDt match {
+ // For NaN, etc.
+ case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e))
+ case _ => assert(arrayData.get(i, elementDt) === e)
+ }
+ } else {
+ assert(arrayData.isNullAt(i))
+ }
+ }
+
+ val seq = arrayData.toSeq[Any](elementDt)
+ array.zipWithIndex.map { case (e, i) =>
+ if (e != null) {
+ elementDt match {
+ // For NaN, etc.
+ case FloatType | DoubleType => assert(seq(i).equals(e))
+ case _ => assert(seq(i) === e)
+ }
+ } else {
+ assert(seq(i) == null)
+ }
+ }
+
+ assert(intercept[IndexOutOfBoundsException] {
+ seq(-1)
+ }.getMessage().contains("must be between 0 and the length of the ArrayData."))
+
+ assert(intercept[IndexOutOfBoundsException] {
+ seq(seq.length)
+ }.getMessage().contains("must be between 0 and the length of the ArrayData."))
+ }
+
+ private def testArrayData(): Unit = {
+ val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
+ DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType,
+ CalendarIntervalType, new ExamplePointUDT())
+ val arrayTypes = elementTypes.flatMap { elementType =>
+ Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true))
+ }
+ val random = new Random(100)
+ arrayTypes.foreach { dt =>
+ val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil)
+ val row = RandomDataGenerator.randomRow(random, schema)
+ val rowConverter = RowEncoder(schema)
+ val internalRow = rowConverter.toRow(row)
+
+ val unsafeRowConverter = UnsafeProjection.create(schema)
+ val safeRowConverter = FromUnsafeProjection(schema)
+
+ val unsafeRow = unsafeRowConverter(internalRow)
+ val safeRow = safeRowConverter(unsafeRow)
+
+ val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData]
+ val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]
+
+ val elementType = dt.elementType
+ test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) {
+ compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType))
+ }
+
+ test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) {
+ compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType))
+ }
+ }
+ }
+
+ testArrayData()
+}
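
Editor's note: the suite exercises the `IndexedSeq` view over both array-backed implementations. A minimal usage sketch, assuming this patch's `ArrayData.toSeq` addition (element values are arbitrary):

    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.types.IntegerType

    object ArrayDataSeqSketch {
      def main(args: Array[String]): Unit = {
        val arrayData = new GenericArrayData(Array[Any](1, null, 3))
        val seq = arrayData.toSeq[Any](IntegerType)
        assert(seq.length == 3)
        assert(seq(0) == 1 && seq(1) == null && seq(2) == 3)
        // Out-of-range access is rejected with the message asserted in the suite above.
        try { seq(3) } catch { case _: IndexOutOfBoundsException => println("out of bounds") }
      }
    }
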
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index 625ff38943fa3..cbf6106697f30 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -490,24 +490,36 @@ class DateTimeUtilsSuite extends SparkFunSuite {
c1.set(1997, 1, 28, 10, 30, 0)
val c2 = Calendar.getInstance()
c2.set(1996, 9, 30, 0, 0, 0)
- assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677)
- c2.set(2000, 1, 28, 0, 0, 0)
- assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36)
- c2.set(2000, 1, 29, 0, 0, 0)
- assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36)
- c2.set(1996, 2, 31, 0, 0, 0)
- assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11)
+ assert(monthsBetween(
+ c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, true, c1.getTimeZone) === 3.94959677)
+ assert(monthsBetween(
+ c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, false, c1.getTimeZone)
+ === 3.9495967741935485)
+ Seq(true, false).foreach { roundOff =>
+ c2.set(2000, 1, 28, 0, 0, 0)
+ assert(monthsBetween(
+ c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36)
+ c2.set(2000, 1, 29, 0, 0, 0)
+ assert(monthsBetween(
+ c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36)
+ c2.set(1996, 2, 31, 0, 0, 0)
+ assert(monthsBetween(
+ c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === 11)
+ }
val c3 = Calendar.getInstance(TimeZonePST)
c3.set(2000, 1, 28, 16, 0, 0)
val c4 = Calendar.getInstance(TimeZonePST)
c4.set(1997, 1, 28, 16, 0, 0)
assert(
- monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZonePST)
+ monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZonePST)
=== 36.0)
assert(
- monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZoneGMT)
+ monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZoneGMT)
=== 35.90322581)
+ assert(
+ monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, false, TimeZoneGMT)
+ === 35.903225806451616)
}
test("from UTC timestamp") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala
new file mode 100644
index 0000000000000..b75739e5a3a65
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+
+class RandomUUIDGeneratorSuite extends SparkFunSuite {
+ test("RandomUUIDGenerator should generate version 4, variant 2 UUIDs") {
+ val generator = RandomUUIDGenerator(new Random().nextLong())
+ for (_ <- 0 to 100) {
+ val uuid = generator.getNextUUID()
+ assert(uuid.version() == 4)
+ assert(uuid.variant() == 2)
+ }
+ }
+
+ test("UUID from RandomUUIDGenerator should be deterministic") {
+ val r1 = new Random(100)
+ val generator1 = RandomUUIDGenerator(r1.nextLong())
+ val r2 = new Random(100)
+ val generator2 = RandomUUIDGenerator(r2.nextLong())
+ val r3 = new Random(101)
+ val generator3 = RandomUUIDGenerator(r3.nextLong())
+
+ for (_ <- 0 to 100) {
+ val uuid1 = generator1.getNextUUID()
+ val uuid2 = generator2.getNextUUID()
+ val uuid3 = generator3.getNextUUID()
+ assert(uuid1 == uuid2)
+ assert(uuid1 != uuid3)
+ }
+ }
+
+ test("Get UTF8String UUID") {
+ val generator = RandomUUIDGenerator(new Random().nextLong())
+ val utf8StringUUID = generator.getNextUUIDUTF8String()
+ val uuid = java.util.UUID.fromString(utf8StringUUID.toString)
+ assert(uuid.version() == 4 && uuid.variant() == 2 && utf8StringUUID.toString == uuid.toString)
+ }
+}
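
Editor's note: the assertions pin down the RFC 4122 bits — version 4 lives in the high nibble of the timestamp-high field, and variant 2 (IETF) in the top two bits of the clock-sequence field. A sketch of how a seeded RNG can yield deterministic UUIDs with those bits forced; this is not the generator's actual implementation, only an illustration of the invariant being tested:

    import java.util.UUID
    import scala.util.Random

    object SeededUuidSketch {
      def nextV4(rng: Random): UUID = {
        val msb = (rng.nextLong() & ~0xF000L) | 0x4000L                   // force version 4
        val lsb = (rng.nextLong() & 0x3FFFFFFFFFFFFFFFL) | Long.MinValue  // force variant 2 ("10" bits)
        new UUID(msb, lsb)
      }

      def main(args: Array[String]): Unit = {
        val a = nextV4(new Random(100))
        val b = nextV4(new Random(100))
        println(a == b)                          // true: same seed, same UUID
        println(a.version() + " " + a.variant()) // 4 2
      }
    }
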
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 8e2b32c2b9a08..5a86f4055dce7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -134,6 +134,14 @@ class DataTypeSuite extends SparkFunSuite {
assert(mapped === expected)
}
+ test("fieldNames and names returns field names") {
+ val struct = StructType(
+ StructField("a", LongType) :: StructField("b", FloatType) :: Nil)
+
+ assert(struct.fieldNames === Seq("a", "b"))
+ assert(struct.names === Seq("a", "b"))
+ }
+
test("merge where right contains type conflict") {
val left = StructType(
StructField("a", LongType) ::
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
new file mode 100644
index 0000000000000..c6ca8bb005429
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.SparkFunSuite
+
+class StructTypeSuite extends SparkFunSuite {
+
+ val s = StructType.fromDDL("a INT, b STRING")
+
+ test("lookup a single missing field should output existing fields") {
+ val e = intercept[IllegalArgumentException](s("c")).getMessage
+ assert(e.contains("Available fields: a, b"))
+ }
+
+ test("lookup a set of missing fields should output existing fields") {
+ val e = intercept[IllegalArgumentException](s(Set("a", "c"))).getMessage
+ assert(e.contains("Available fields: a, b"))
+ }
+
+ test("lookup fieldIndex for missing field should output existing fields") {
+ val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage
+ assert(e.contains("Available fields: a, b"))
+ }
+}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index ef41837f89d68..f270c70fbfcf0 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -38,7 +38,7 @@
       <groupId>com.univocity</groupId>
       <artifactId>univocity-parsers</artifactId>
-      <version>2.5.9</version>
+      <version>2.6.3</version>
       <type>jar</type>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index 730a4ae8d5605..74c9c05992719 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -62,10 +62,14 @@ public long durationMs() {
*/
public abstract void init(int index, Iterator<InternalRow>[] iters);
+ /*
+ * The following four methods are public so that they can also be accessed from
+ * methods in inner classes. See SPARK-23598.
+ */
/**
* Append a row to currentRows.
*/
- protected void append(InternalRow row) {
+ public void append(InternalRow row) {
currentRows.add(row);
}
@@ -75,7 +79,7 @@ protected void append(InternalRow row) {
* If it returns true, the caller should exit the loop that [[InputAdapter]] generates.
* This interface is mainly used to limit the number of input rows.
*/
- protected boolean stopEarly() {
+ public boolean stopEarly() {
return false;
}
@@ -84,14 +88,14 @@ protected boolean stopEarly() {
*
* If it returns true, the caller should exit the loop (return from processNext()).
*/
- protected boolean shouldStop() {
+ public boolean shouldStop() {
return !currentRows.isEmpty();
}
/**
* Increase the peak execution memory for current task.
*/
- protected void incPeakExecutionMemory(long size) {
+ public void incPeakExecutionMemory(long size) {
TaskContext.get().taskMetrics().incPeakExecutionMemory(size);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index b0b5383a081a0..9eb03430a7db2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -34,6 +34,7 @@
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.unsafe.sort.*;
@@ -98,19 +99,33 @@ public UnsafeKVExternalSorter(
numElementsForSpillThreshold,
canUseRadixSort);
} else {
- // The array will be used to do in-place sort, which require half of the space to be empty.
- // Note: each record in the map takes two entries in the array, one is record pointer,
- // another is the key prefix.
- assert(map.numKeys() * 2 <= map.getArray().size() / 2);
- // During spilling, the array in map will not be used, so we can borrow that and use it
- // as the underlying array for in-memory sorter (it's always large enough).
- // Since we will not grow the array, it's fine to pass `null` as consumer.
+ // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow
+ // that and use it as the pointer array for `UnsafeInMemorySorter`.
+ LongArray pointerArray = map.getArray();
+ // `BytesToBytesMap`'s pointer array is only guaranteed to hold all the distinct keys, but
+ // `UnsafeInMemorySorter`'s pointer array needs to hold all the entries. Since
+ // `BytesToBytesMap` can have duplicated keys, here we need a check to make sure the pointer
+ // array can hold all the entries in `BytesToBytesMap`.
+ // The pointer array will be used to do in-place sort, which requires half of the space to be
+ // empty. Note: each record in the map takes two entries in the pointer array, one is record
+ // pointer, another is key prefix. So the required size of pointer array is `numRecords * 4`.
+ // TODO: It's possible to change UnsafeInMemorySorter to have multiple entries with same key,
+ // so that we can always reuse the pointer array.
+ if (map.numValues() > pointerArray.size() / 4) {
+ // Here we ask the map to allocate memory, so that the memory manager won't ask the map
+ // to spill, if the memory is not enough.
+ pointerArray = map.allocateArray(map.numValues() * 4L);
+ }
+
+ // Since the pointer array (either reused from the map or newly allocated) is guaranteed
+ // to be large enough, it's fine to pass `null` as the consumer because we won't allocate more
+ // memory.
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
null,
taskMemoryManager,
comparatorSupplier.get(),
prefixComparator,
- map.getArray(),
+ pointerArray,
canUseRadixSort);
// We cannot use the destructive iterator here because we are reusing the existing memory
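
Editor's note: the new comment spells out the sizing rule — each record needs two longs in the pointer array (record pointer plus key prefix) and in-place sorting needs the array to be at most half full, so reuse is only safe when `numRecords * 4 <= array.size`. A tiny sketch of that check:

    object PointerArraySizingSketch {
      // Each record occupies 2 longs (pointer + prefix) and in-place sort needs 2x headroom,
      // so the required capacity in longs is numRecords * 4.
      def requiredLongs(numRecords: Long): Long = numRecords * 4

      def canReuseMapArray(numRecords: Long, mapArrayLongs: Long): Boolean =
        numRecords <= mapArrayLongs / 4

      def main(args: Array[String]): Unit = {
        println(canReuseMapArray(numRecords = 100, mapArrayLongs = 512)) // true: 100 <= 128
        println(canReuseMapArray(numRecords = 200, mapArrayLongs = 512)) // false: must allocate 800 longs
      }
    }
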
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java
new file mode 100644
index 0000000000000..82a1169cbe7ae
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * Exception thrown when the Parquet reader finds a column type mismatch.
+ */
+@InterfaceStability.Unstable
+public class SchemaColumnConvertNotSupportedException extends RuntimeException {
+
+ /**
+ * Name of the column which cannot be converted.
+ */
+ private String column;
+ /**
+ * Physical column type in the actual parquet file.
+ */
+ private String physicalType;
+ /**
+ * Logical column type in the Parquet schema that the Parquet reader uses to parse all files.
+ */
+ private String logicalType;
+
+ public String getColumn() {
+ return column;
+ }
+
+ public String getPhysicalType() {
+ return physicalType;
+ }
+
+ public String getLogicalType() {
+ return logicalType;
+ }
+
+ public SchemaColumnConvertNotSupportedException(
+ String column,
+ String physicalType,
+ String logicalType) {
+ super();
+ this.column = column;
+ this.physicalType = physicalType;
+ this.logicalType = logicalType;
+ }
+}
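
Editor's note: a hypothetical caller-side sketch of how the new exception's accessors can be turned into a readable error. The wrapping code and `readNextBatch` are illustrative, not part of this patch:

    import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException

    object ParquetMismatchHandlingSketch {
      // Hypothetical read loop; only the catch block matters here.
      def readNextBatch(): Unit = throw new SchemaColumnConvertNotSupportedException(
        "[price]", "INT64", "IntegerType")

      def main(args: Array[String]): Unit = {
        try readNextBatch()
        catch {
          case e: SchemaColumnConvertNotSupportedException =>
            println(s"Parquet column ${e.getColumn} is stored as ${e.getPhysicalType} " +
              s"but the requested Spark type is ${e.getLogicalType}")
        }
      }
    }
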
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
index 12f4d658b1868..9bfad1e83ee7b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
@@ -136,7 +136,7 @@ public int getInt(int rowId) {
public long getLong(int rowId) {
int index = getRowIndex(rowId);
if (isTimestamp) {
- return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000;
+ return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000 % 1000;
} else {
return longData.vector[index];
}
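
Editor's note: the added `% 1000` matters because `TimestampColumnVector.time` is milliseconds since the epoch (already including the sub-second millis) while `nanos` holds the full fractional-second nanoseconds, so `nanos / 1000` alone double-counts the millisecond part. A worked example with made-up values:

    object OrcTimestampMicrosSketch {
      def main(args: Array[String]): Unit = {
        // Some epoch second, plus a .123456789 fraction: time carries the 123 ms,
        // nanos carries the whole fraction.
        val timeMillis = 1000000L * 1000 + 123
        val nanos = 123456789L

        val buggyMicros = timeMillis * 1000 + nanos / 1000        // adds 123456, double-counting 123000
        val fixedMicros = timeMillis * 1000 + nanos / 1000 % 1000 // adds only the 456 sub-millisecond micros

        println(buggyMicros - fixedMicros) // 123000: exactly the duplicated millisecond part
      }
    }
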
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
index dcebdc39f0aa2..a0d9578a377b1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
@@ -497,7 +497,7 @@ private void putValues(
* Returns the number of micros since epoch from an element of TimestampColumnVector.
*/
private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) {
- return vector.time[index] * 1000L + vector.nanos[index] / 1000L;
+ return vector.time[index] * 1000 + (vector.nanos[index] / 1000 % 1000);
}
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index e65cd252c3ddf..c975e52734e01 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.datasources.parquet;
-import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
@@ -147,7 +146,8 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont
this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString);
this.reader = new ParquetFileReader(
configuration, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns());
- for (BlockMetaData block : blocks) {
+ // use the blocks from the reader in case some do not match filters and will not be read
+ for (BlockMetaData block : reader.getRowGroups()) {
this.totalRowCount += block.getRowCount();
}
@@ -225,7 +225,8 @@ protected void initialize(String path, List<String> columns) throws IOException
this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema);
this.reader = new ParquetFileReader(
config, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns());
- for (BlockMetaData block : blocks) {
+ // use the blocks from the reader in case some do not match filters and will not be read
+ for (BlockMetaData block : reader.getRowGroups()) {
this.totalRowCount += block.getRowCount();
}
}
@@ -293,7 +294,7 @@ protected static IntIterator createRLEIterator(
return new RLEIntIterator(
new RunLengthBitPackingHybridDecoder(
BytesUtils.getWidthFromMaxInt(maxLevel),
- new ByteArrayInputStream(bytes.toByteArray())));
+ bytes.toInputStream()));
} catch (IOException e) {
throw new IOException("could not read levels in page for col " + descriptor, e);
}
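
Editor's note: counting rows from `reader.getRowGroups()` rather than the raw block list keeps `totalRowCount` consistent with what the reader will actually return once row-group filters have dropped some blocks. A sketch of the intent, with a simplified stand-in for parquet-mr's `BlockMetaData`:

    object RowGroupCountSketch {
      // Simplified stand-in for parquet-mr's BlockMetaData.
      case class RowGroup(rowCount: Long, survivesFilter: Boolean)

      def main(args: Array[String]): Unit = {
        val allBlocks = Seq(RowGroup(100, true), RowGroup(250, false), RowGroup(50, true))
        val readerBlocks = allBlocks.filter(_.survivesFilter) // what reader.getRowGroups() returns

        val buggyTotal = allBlocks.map(_.rowCount).sum    // 400: counts rows that will never be read
        val fixedTotal = readerBlocks.map(_.rowCount).sum // 150: matches what the reader produces
        println(s"$buggyTotal vs $fixedTotal")
      }
    }
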
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index c120863152a96..d5969b55eef96 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -18,8 +18,11 @@
package org.apache.spark.sql.execution.datasources.parquet;
import java.io.IOException;
+import java.util.Arrays;
import java.util.TimeZone;
+import org.apache.parquet.bytes.ByteBufferInputStream;
+import org.apache.parquet.bytes.BytesInput;
import org.apache.parquet.bytes.BytesUtils;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.Dictionary;
@@ -31,6 +34,7 @@
import org.apache.parquet.schema.PrimitiveType;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;
@@ -231,6 +235,18 @@ private boolean shouldConvertTimestamps() {
return convertTz != null && !convertTz.equals(UTC);
}
+ /**
+ * Helper function to construct exception for parquet schema mismatch.
+ */
+ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedException(
+ ColumnDescriptor descriptor,
+ WritableColumnVector column) {
+ return new SchemaColumnConvertNotSupportedException(
+ Arrays.toString(descriptor.getPath()),
+ descriptor.getType().toString(),
+ column.dataType().toString());
+ }
+
/**
* Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
*/
@@ -261,7 +277,7 @@ private void decodeDictionaryIds(
}
}
} else {
- throw new UnsupportedOperationException("Unimplemented type: " + column.dataType());
+ throw constructConvertNotSupportedException(descriptor, column);
}
break;
@@ -282,7 +298,7 @@ private void decodeDictionaryIds(
}
}
} else {
- throw new UnsupportedOperationException("Unimplemented type: " + column.dataType());
+ throw constructConvertNotSupportedException(descriptor, column);
}
break;
@@ -321,7 +337,7 @@ private void decodeDictionaryIds(
}
}
} else {
- throw new UnsupportedOperationException();
+ throw constructConvertNotSupportedException(descriptor, column);
}
break;
case BINARY:
@@ -360,7 +376,7 @@ private void decodeDictionaryIds(
}
}
} else {
- throw new UnsupportedOperationException();
+ throw constructConvertNotSupportedException(descriptor, column);
}
break;
@@ -374,13 +390,16 @@ private void decodeDictionaryIds(
* is guaranteed that num is smaller than the number of values left in the current page.
*/
- private void readBooleanBatch(int rowId, int num, WritableColumnVector column) {
- assert(column.dataType() == DataTypes.BooleanType);
+ private void readBooleanBatch(int rowId, int num, WritableColumnVector column)
+ throws IOException {
+ if (column.dataType() != DataTypes.BooleanType) {
+ throw constructConvertNotSupportedException(descriptor, column);
+ }
defColumn.readBooleans(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
}
- private void readIntBatch(int rowId, int num, WritableColumnVector column) {
+ private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
@@ -394,11 +413,11 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) {
defColumn.readShorts(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else {
- throw new UnsupportedOperationException("Unimplemented type: " + column.dataType());
+ throw constructConvertNotSupportedException(descriptor, column);
}
}
- private void readLongBatch(int rowId, int num, WritableColumnVector column) {
+ private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
if (column.dataType() == DataTypes.LongType ||
DecimalType.is64BitDecimalType(column.dataType()) ||
@@ -414,37 +433,38 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) {
}
}
} else {
- throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType());
+ throw constructConvertNotSupportedException(descriptor, column);
}
}
- private void readFloatBatch(int rowId, int num, WritableColumnVector column) {
+ private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: support implicit cast to double?
if (column.dataType() == DataTypes.FloatType) {
defColumn.readFloats(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else {
- throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType());
+ throw constructConvertNotSupportedException(descriptor, column);
}
}
- private void readDoubleBatch(int rowId, int num, WritableColumnVector column) {
+ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.DoubleType) {
defColumn.readDoubles(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else {
- throw new UnsupportedOperationException("Unimplemented type: " + column.dataType());
+ throw constructConvertNotSupportedException(descriptor, column);
}
}
- private void readBinaryBatch(int rowId, int num, WritableColumnVector column) {
+ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
- if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType) {
+ if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType
+ || DecimalType.isByteArrayDecimalType(column.dataType())) {
defColumn.readBinarys(num, column, rowId, maxDefLevel, data);
} else if (column.dataType() == DataTypes.TimestampType) {
if (!shouldConvertTimestamps()) {
@@ -470,7 +490,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) {
}
}
} else {
- throw new UnsupportedOperationException("Unimplemented type: " + column.dataType());
+ throw constructConvertNotSupportedException(descriptor, column);
}
}
@@ -509,7 +529,7 @@ private void readFixedLenByteArrayBatch(
}
}
} else {
- throw new UnsupportedOperationException("Unimplemented type: " + column.dataType());
+ throw constructConvertNotSupportedException(descriptor, column);
}
}
@@ -539,7 +559,7 @@ public Void visit(DataPageV2 dataPageV2) {
});
}
- private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException {
+ private void initDataReader(Encoding dataEncoding, ByteBufferInputStream in) throws IOException {
this.endOfPageValueCount = valuesRead + pageValueCount;
if (dataEncoding.usesDictionary()) {
this.dataColumn = null;
@@ -564,7 +584,7 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr
}
try {
- dataColumn.initFromPage(pageValueCount, bytes, offset);
+ dataColumn.initFromPage(pageValueCount, in);
} catch (IOException e) {
throw new IOException("could not read page in col " + descriptor, e);
}
@@ -585,12 +605,11 @@ private void readPageV1(DataPageV1 page) throws IOException {
this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader);
this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader);
try {
- byte[] bytes = page.getBytes().toByteArray();
- rlReader.initFromPage(pageValueCount, bytes, 0);
- int next = rlReader.getNextOffset();
- dlReader.initFromPage(pageValueCount, bytes, next);
- next = dlReader.getNextOffset();
- initDataReader(page.getValueEncoding(), bytes, next);
+ BytesInput bytes = page.getBytes();
+ ByteBufferInputStream in = bytes.toInputStream();
+ rlReader.initFromPage(pageValueCount, in);
+ dlReader.initFromPage(pageValueCount, in);
+ initDataReader(page.getValueEncoding(), in);
} catch (IOException e) {
throw new IOException("could not read page " + page + " in col " + descriptor, e);
}
@@ -602,12 +621,13 @@ private void readPageV2(DataPageV2 page) throws IOException {
page.getRepetitionLevels(), descriptor);
int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
- this.defColumn = new VectorizedRleValuesReader(bitWidth);
+ // do not read the length from the stream. v2 pages handle dividing the page bytes.
+ this.defColumn = new VectorizedRleValuesReader(bitWidth, false);
this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn);
- this.defColumn.initFromBuffer(
- this.pageValueCount, page.getDefinitionLevels().toByteArray());
+ this.defColumn.initFromPage(
+ this.pageValueCount, page.getDefinitionLevels().toInputStream());
try {
- initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0);
+ initDataReader(page.getDataEncoding(), page.getData().toInputStream());
} catch (IOException e) {
throw new IOException("could not read page " + page + " in col " + descriptor, e);
}
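
The readPageV1 change above drops the manual byte-offset bookkeeping (initFromPage over a byte[] plus getNextOffset) in favor of one sequential stream that the repetition-level, definition-level, and data readers consume in order. A rough, self-contained sketch of that idea using a plain JDK InputStream rather than Parquet's ByteBufferInputStream; the page layout and helper below are illustrative only:

    import java.io.ByteArrayInputStream;
    import java.io.IOException;
    import java.io.InputStream;

    // Sketch: three "readers" consume one shared stream in order, so no caller
    // needs to track byte offsets between the sections of a page.
    public class SequentialPageInitSketch {
      static int readLittleEndianInt(InputStream in) throws IOException {
        int b0 = in.read(), b1 = in.read(), b2 = in.read(), b3 = in.read();
        return (b3 << 24) | (b2 << 16) | (b1 << 8) | b0;
      }

      public static void main(String[] args) throws IOException {
        // A fake "page": two little-endian ints (level headers) followed by payload bytes.
        byte[] page = {1, 0, 0, 0, 2, 0, 0, 0, 42, 43};
        InputStream in = new ByteArrayInputStream(page);

        int repetitionLevelHeader = readLittleEndianInt(in);  // consumes bytes 0..3
        int definitionLevelHeader = readLittleEndianInt(in);  // consumes bytes 4..7
        int firstDataByte = in.read();                        // data reader starts at byte 8

        System.out.println(repetitionLevelHeader + " " + definitionLevelHeader + " " + firstDataByte);
      }
    }
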
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index 5b75f719339fb..c62dc3d86386e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -20,8 +20,9 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import org.apache.parquet.bytes.ByteBufferInputStream;
+import org.apache.parquet.io.ParquetDecodingException;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
-import org.apache.spark.unsafe.Platform;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.io.api.Binary;
@@ -30,24 +31,18 @@
* An implementation of the Parquet PLAIN decoder that supports the vectorized interface.
*/
public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader {
- private byte[] buffer;
- private int offset;
- private int bitOffset; // Only used for booleans.
- private ByteBuffer byteBuffer; // used to wrap the byte array buffer
+ private ByteBufferInputStream in = null;
- private static final boolean bigEndianPlatform =
- ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN);
+ // Only used for booleans.
+ private int bitOffset;
+ private byte currentByte = 0;
public VectorizedPlainValuesReader() {
}
@Override
- public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOException {
- this.buffer = bytes;
- this.offset = offset + Platform.BYTE_ARRAY_OFFSET;
- if (bigEndianPlatform) {
- byteBuffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
- }
+ public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException {
+ this.in = in;
}
@Override
@@ -63,115 +58,157 @@ public final void readBooleans(int total, WritableColumnVector c, int rowId) {
}
}
+ private ByteBuffer getBuffer(int length) {
+ try {
+ return in.slice(length).order(ByteOrder.LITTLE_ENDIAN);
+ } catch (IOException e) {
+ throw new ParquetDecodingException("Failed to read " + length + " bytes", e);
+ }
+ }
+
@Override
public final void readIntegers(int total, WritableColumnVector c, int rowId) {
- c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
- offset += 4 * total;
+ int requiredBytes = total * 4;
+ ByteBuffer buffer = getBuffer(requiredBytes);
+
+ if (buffer.hasArray()) {
+ int offset = buffer.arrayOffset() + buffer.position();
+ c.putIntsLittleEndian(rowId, total, buffer.array(), offset);
+ } else {
+ for (int i = 0; i < total; i += 1) {
+ c.putInt(rowId + i, buffer.getInt());
+ }
+ }
}
@Override
public final void readLongs(int total, WritableColumnVector c, int rowId) {
- c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
- offset += 8 * total;
+ int requiredBytes = total * 8;
+ ByteBuffer buffer = getBuffer(requiredBytes);
+
+ if (buffer.hasArray()) {
+ int offset = buffer.arrayOffset() + buffer.position();
+ c.putLongsLittleEndian(rowId, total, buffer.array(), offset);
+ } else {
+ for (int i = 0; i < total; i += 1) {
+ c.putLong(rowId + i, buffer.getLong());
+ }
+ }
}
@Override
public final void readFloats(int total, WritableColumnVector c, int rowId) {
- c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
- offset += 4 * total;
+ int requiredBytes = total * 4;
+ ByteBuffer buffer = getBuffer(requiredBytes);
+
+ if (buffer.hasArray()) {
+ int offset = buffer.arrayOffset() + buffer.position();
+ c.putFloats(rowId, total, buffer.array(), offset);
+ } else {
+ for (int i = 0; i < total; i += 1) {
+ c.putFloat(rowId + i, buffer.getFloat());
+ }
+ }
}
@Override
public final void readDoubles(int total, WritableColumnVector c, int rowId) {
- c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
- offset += 8 * total;
+ int requiredBytes = total * 8;
+ ByteBuffer buffer = getBuffer(requiredBytes);
+
+ if (buffer.hasArray()) {
+ int offset = buffer.arrayOffset() + buffer.position();
+ c.putDoubles(rowId, total, buffer.array(), offset);
+ } else {
+ for (int i = 0; i < total; i += 1) {
+ c.putDouble(rowId + i, buffer.getDouble());
+ }
+ }
}
@Override
public final void readBytes(int total, WritableColumnVector c, int rowId) {
- for (int i = 0; i < total; i++) {
- // Bytes are stored as a 4-byte little endian int. Just read the first byte.
- // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
- c.putByte(rowId + i, Platform.getByte(buffer, offset));
- offset += 4;
+ // Bytes are stored as a 4-byte little endian int. Just read the first byte.
+ // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
+ int requiredBytes = total * 4;
+ ByteBuffer buffer = getBuffer(requiredBytes);
+
+ for (int i = 0; i < total; i += 1) {
+ c.putByte(rowId + i, buffer.get());
+ // skip the next 3 bytes
+ buffer.position(buffer.position() + 3);
}
}
@Override
public final boolean readBoolean() {
- byte b = Platform.getByte(buffer, offset);
- boolean v = (b & (1 << bitOffset)) != 0;
+ // TODO: vectorize decoding and keep boolean[] instead of currentByte
+ if (bitOffset == 0) {
+ try {
+ currentByte = (byte) in.read();
+ } catch (IOException e) {
+ throw new ParquetDecodingException("Failed to read a byte", e);
+ }
+ }
+
+ boolean v = (currentByte & (1 << bitOffset)) != 0;
bitOffset += 1;
if (bitOffset == 8) {
bitOffset = 0;
- offset++;
}
return v;
}
@Override
public final int readInteger() {
- int v = Platform.getInt(buffer, offset);
- if (bigEndianPlatform) {
- v = java.lang.Integer.reverseBytes(v);
- }
- offset += 4;
- return v;
+ return getBuffer(4).getInt();
}
@Override
public final long readLong() {
- long v = Platform.getLong(buffer, offset);
- if (bigEndianPlatform) {
- v = java.lang.Long.reverseBytes(v);
- }
- offset += 8;
- return v;
+ return getBuffer(8).getLong();
}
@Override
public final byte readByte() {
- return (byte)readInteger();
+ return (byte) readInteger();
}
@Override
public final float readFloat() {
- float v;
- if (!bigEndianPlatform) {
- v = Platform.getFloat(buffer, offset);
- } else {
- v = byteBuffer.getFloat(offset - Platform.BYTE_ARRAY_OFFSET);
- }
- offset += 4;
- return v;
+ return getBuffer(4).getFloat();
}
@Override
public final double readDouble() {
- double v;
- if (!bigEndianPlatform) {
- v = Platform.getDouble(buffer, offset);
- } else {
- v = byteBuffer.getDouble(offset - Platform.BYTE_ARRAY_OFFSET);
- }
- offset += 8;
- return v;
+ return getBuffer(8).getDouble();
}
@Override
public final void readBinary(int total, WritableColumnVector v, int rowId) {
for (int i = 0; i < total; i++) {
int len = readInteger();
- int start = offset;
- offset += len;
- v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len);
+ ByteBuffer buffer = getBuffer(len);
+ if (buffer.hasArray()) {
+ v.putByteArray(rowId + i, buffer.array(), buffer.arrayOffset() + buffer.position(), len);
+ } else {
+ byte[] bytes = new byte[len];
+ buffer.get(bytes);
+ v.putByteArray(rowId + i, bytes);
+ }
}
}
@Override
public final Binary readBinary(int len) {
- Binary result = Binary.fromConstantByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len);
- offset += len;
- return result;
+ ByteBuffer buffer = getBuffer(len);
+ if (buffer.hasArray()) {
+ return Binary.fromConstantByteArray(
+ buffer.array(), buffer.arrayOffset() + buffer.position(), len);
+ } else {
+ byte[] bytes = new byte[len];
+ buffer.get(bytes);
+ return Binary.fromConstantByteArray(bytes);
+ }
}
}
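
The getBuffer(length) helper above slices the requested bytes from the page stream and then prefers a bulk copy through the backing array when ByteBuffer.hasArray() is true, falling back to per-element reads for direct (off-heap) buffers. A minimal sketch of that two-path pattern with a plain java.nio.ByteBuffer standing in for the slice returned by Parquet's ByteBufferInputStream; the class and method names are made up:

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;

    public class HasArrayFastPathSketch {
      // Copy `total` little-endian ints out of `buffer`, mirroring the two paths used above.
      static int[] readInts(ByteBuffer buffer, int total) {
        int[] out = new int[total];
        if (buffer.hasArray()) {
          // Fast path: the buffer is heap-backed, so read straight from its array.
          byte[] array = buffer.array();
          int offset = buffer.arrayOffset() + buffer.position();
          for (int i = 0; i < total; i++) {
            int base = offset + 4 * i;
            out[i] = (array[base] & 0xFF)
                | (array[base + 1] & 0xFF) << 8
                | (array[base + 2] & 0xFF) << 16
                | (array[base + 3] & 0xFF) << 24;
          }
        } else {
          // Slow path: direct buffer, read through the ByteBuffer API.
          for (int i = 0; i < total; i++) {
            out[i] = buffer.getInt();
          }
        }
        return out;
      }

      public static void main(String[] args) {
        ByteBuffer heap = ByteBuffer.wrap(new byte[] {1, 0, 0, 0, 2, 0, 0, 0})
            .order(ByteOrder.LITTLE_ENDIAN);
        int[] values = readInts(heap, 2);
        System.out.println(values[0] + " " + values[1]);  // 1 2
      }
    }
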
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index fc7fa70c39419..fe3d31ae8e746 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources.parquet;
import org.apache.parquet.Preconditions;
+import org.apache.parquet.bytes.ByteBufferInputStream;
import org.apache.parquet.bytes.BytesUtils;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.column.values.bitpacking.BytePacker;
@@ -27,6 +28,9 @@
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
/**
* A values reader for Parquet's run-length encoded data. This is based off of the version in
* parquet-mr with these changes:
@@ -49,9 +53,7 @@ private enum MODE {
}
// Encoded data.
- private byte[] in;
- private int end;
- private int offset;
+ private ByteBufferInputStream in;
// bit/byte width of decoded data and utility to batch unpack them.
private int bitWidth;
@@ -70,45 +72,40 @@ private enum MODE {
// If true, the bit width is fixed. This decoder is used in different places and this also
// controls if we need to read the bitwidth from the beginning of the data stream.
private final boolean fixedWidth;
+ private final boolean readLength;
public VectorizedRleValuesReader() {
- fixedWidth = false;
+ this.fixedWidth = false;
+ this.readLength = false;
}
public VectorizedRleValuesReader(int bitWidth) {
- fixedWidth = true;
+ this.fixedWidth = true;
+ this.readLength = bitWidth != 0;
+ init(bitWidth);
+ }
+
+ public VectorizedRleValuesReader(int bitWidth, boolean readLength) {
+ this.fixedWidth = true;
+ this.readLength = readLength;
init(bitWidth);
}
@Override
- public void initFromPage(int valueCount, byte[] page, int start) {
- this.offset = start;
- this.in = page;
+ public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException {
+ this.in = in;
if (fixedWidth) {
- if (bitWidth != 0) {
+ // initialize for repetition and definition levels
+ if (readLength) {
int length = readIntLittleEndian();
- this.end = this.offset + length;
+ this.in = in.sliceStream(length);
}
} else {
- this.end = page.length;
- if (this.end != this.offset) init(page[this.offset++] & 255);
- }
- if (bitWidth == 0) {
- // 0 bit width, treat this as an RLE run of valueCount number of 0's.
- this.mode = MODE.RLE;
- this.currentCount = valueCount;
- this.currentValue = 0;
- } else {
- this.currentCount = 0;
+ // initialize for values
+ if (in.available() > 0) {
+ init(in.read());
+ }
}
- }
-
- // Initialize the reader from a buffer. This is used for the V2 page encoding where the
- // definition are in its own buffer.
- public void initFromBuffer(int valueCount, byte[] data) {
- this.offset = 0;
- this.in = data;
- this.end = data.length;
if (bitWidth == 0) {
// 0 bit width, treat this as an RLE run of valueCount number of 0's.
this.mode = MODE.RLE;
@@ -129,11 +126,6 @@ private void init(int bitWidth) {
this.packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth);
}
- @Override
- public int getNextOffset() {
- return this.end;
- }
-
@Override
public boolean readBoolean() {
return this.readInteger() != 0;
@@ -182,7 +174,7 @@ public void readIntegers(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -217,7 +209,7 @@ public void readBooleans(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -251,7 +243,7 @@ public void readBytes(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -285,7 +277,7 @@ public void readShorts(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -321,7 +313,7 @@ public void readLongs(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -355,7 +347,7 @@ public void readFloats(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -389,7 +381,7 @@ public void readDoubles(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -423,7 +415,7 @@ public void readBinarys(
WritableColumnVector c,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -462,7 +454,7 @@ public void readIntegers(
WritableColumnVector nulls,
int rowId,
int level,
- VectorizedValuesReader data) {
+ VectorizedValuesReader data) throws IOException {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -559,12 +551,12 @@ public Binary readBinary(int len) {
/**
* Reads the next varint encoded int.
*/
- private int readUnsignedVarInt() {
+ private int readUnsignedVarInt() throws IOException {
int value = 0;
int shift = 0;
int b;
do {
- b = in[offset++] & 255;
+ b = in.read();
value |= (b & 0x7F) << shift;
shift += 7;
} while ((b & 0x80) != 0);
@@ -574,35 +566,32 @@ private int readUnsignedVarInt() {
/**
* Reads the next 4 byte little endian int.
*/
- private int readIntLittleEndian() {
- int ch4 = in[offset] & 255;
- int ch3 = in[offset + 1] & 255;
- int ch2 = in[offset + 2] & 255;
- int ch1 = in[offset + 3] & 255;
- offset += 4;
+ private int readIntLittleEndian() throws IOException {
+ int ch4 = in.read();
+ int ch3 = in.read();
+ int ch2 = in.read();
+ int ch1 = in.read();
return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0));
}
/**
* Reads the next byteWidth little endian int.
*/
- private int readIntLittleEndianPaddedOnBitWidth() {
+ private int readIntLittleEndianPaddedOnBitWidth() throws IOException {
switch (bytesWidth) {
case 0:
return 0;
case 1:
- return in[offset++] & 255;
+ return in.read();
case 2: {
- int ch2 = in[offset] & 255;
- int ch1 = in[offset + 1] & 255;
- offset += 2;
+ int ch2 = in.read();
+ int ch1 = in.read();
return (ch1 << 8) + ch2;
}
case 3: {
- int ch3 = in[offset] & 255;
- int ch2 = in[offset + 1] & 255;
- int ch1 = in[offset + 2] & 255;
- offset += 3;
+ int ch3 = in.read();
+ int ch2 = in.read();
+ int ch1 = in.read();
return (ch1 << 16) + (ch2 << 8) + (ch3 << 0);
}
case 4: {
@@ -619,32 +608,36 @@ private int ceil8(int value) {
/**
* Reads the next group.
*/
- private void readNextGroup() {
- int header = readUnsignedVarInt();
- this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED;
- switch (mode) {
- case RLE:
- this.currentCount = header >>> 1;
- this.currentValue = readIntLittleEndianPaddedOnBitWidth();
- return;
- case PACKED:
- int numGroups = header >>> 1;
- this.currentCount = numGroups * 8;
- int bytesToRead = ceil8(this.currentCount * this.bitWidth);
-
- if (this.currentBuffer.length < this.currentCount) {
- this.currentBuffer = new int[this.currentCount];
- }
- currentBufferIdx = 0;
- int valueIndex = 0;
- for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) {
- this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex);
- valueIndex += 8;
- }
- offset += bytesToRead;
- return;
- default:
- throw new ParquetDecodingException("not a valid mode " + this.mode);
+ private void readNextGroup() {
+ try {
+ int header = readUnsignedVarInt();
+ this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED;
+ switch (mode) {
+ case RLE:
+ this.currentCount = header >>> 1;
+ this.currentValue = readIntLittleEndianPaddedOnBitWidth();
+ return;
+ case PACKED:
+ int numGroups = header >>> 1;
+ this.currentCount = numGroups * 8;
+
+ if (this.currentBuffer.length < this.currentCount) {
+ this.currentBuffer = new int[this.currentCount];
+ }
+ currentBufferIdx = 0;
+ int valueIndex = 0;
+ while (valueIndex < this.currentCount) {
+ // values are bit packed 8 at a time, so reading bitWidth will always work
+ ByteBuffer buffer = in.slice(bitWidth);
+ this.packer.unpack8Values(buffer, buffer.position(), this.currentBuffer, valueIndex);
+ valueIndex += 8;
+ }
+ return;
+ default:
+ throw new ParquetDecodingException("not a valid mode " + this.mode);
+ }
+ } catch (IOException e) {
+ throw new ParquetDecodingException("Failed to read from input stream", e);
}
}
}
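
readNextGroup() above follows Parquet's RLE/bit-packing hybrid layout: an unsigned varint header whose low bit selects the mode, with the remaining bits holding either the run length (RLE) or the number of 8-value bit-packed groups (PACKED). A small, self-contained sketch of just the header decoding under those assumptions; the example bytes and class name are made up:

    import java.io.ByteArrayInputStream;
    import java.io.IOException;
    import java.io.InputStream;

    public class RleHybridHeaderSketch {
      // Read a ULEB128 varint, as used for the hybrid run headers.
      static int readUnsignedVarInt(InputStream in) throws IOException {
        int value = 0, shift = 0, b;
        do {
          b = in.read();
          value |= (b & 0x7F) << shift;
          shift += 7;
        } while ((b & 0x80) != 0);
        return value;
      }

      public static void main(String[] args) throws IOException {
        // 0x08 = (4 << 1) | 0: an RLE run of 4 repeated values.
        // 0x03 = (1 << 1) | 1: one bit-packed group, i.e. 8 packed values.
        InputStream in = new ByteArrayInputStream(new byte[] {0x08, 0x03});

        for (int i = 0; i < 2; i++) {
          int header = readUnsignedVarInt(in);
          if ((header & 1) == 0) {
            System.out.println("RLE run, count = " + (header >>> 1));
          } else {
            System.out.println("PACKED, values = " + (header >>> 1) * 8);
          }
        }
      }
    }
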
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 754c26579ff08..6fdadde628551 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -23,6 +23,7 @@
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.OffHeapMemoryBlock;
import org.apache.spark.unsafe.types.UTF8String;
/**
@@ -206,7 +207,7 @@ public byte[] getBytes(int rowId, int count) {
@Override
protected UTF8String getBytesAsUTF8String(int rowId, int count) {
- return UTF8String.fromAddress(null, data + rowId, count);
+ return new UTF8String(new OffHeapMemoryBlock(data + rowId, count));
}
//
@@ -215,12 +216,12 @@ protected UTF8String getBytesAsUTF8String(int rowId, int count) {
@Override
public void putShort(int rowId, short value) {
- Platform.putShort(null, data + 2 * rowId, value);
+ Platform.putShort(null, data + 2L * rowId, value);
}
@Override
public void putShorts(int rowId, int count, short value) {
- long offset = data + 2 * rowId;
+ long offset = data + 2L * rowId;
for (int i = 0; i < count; ++i, offset += 2) {
Platform.putShort(null, offset, value);
}
@@ -228,20 +229,20 @@ public void putShorts(int rowId, int count, short value) {
@Override
public void putShorts(int rowId, int count, short[] src, int srcIndex) {
- Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2,
- null, data + 2 * rowId, count * 2);
+ Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2L,
+ null, data + 2L * rowId, count * 2L);
}
@Override
public void putShorts(int rowId, int count, byte[] src, int srcIndex) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
- null, data + rowId * 2, count * 2);
+ null, data + rowId * 2L, count * 2L);
}
@Override
public short getShort(int rowId) {
if (dictionary == null) {
- return Platform.getShort(null, data + 2 * rowId);
+ return Platform.getShort(null, data + 2L * rowId);
} else {
return (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId));
}
@@ -251,7 +252,7 @@ public short getShort(int rowId) {
public short[] getShorts(int rowId, int count) {
assert(dictionary == null);
short[] array = new short[count];
- Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2);
+ Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L);
return array;
}
@@ -261,12 +262,12 @@ public short[] getShorts(int rowId, int count) {
@Override
public void putInt(int rowId, int value) {
- Platform.putInt(null, data + 4 * rowId, value);
+ Platform.putInt(null, data + 4L * rowId, value);
}
@Override
public void putInts(int rowId, int count, int value) {
- long offset = data + 4 * rowId;
+ long offset = data + 4L * rowId;
for (int i = 0; i < count; ++i, offset += 4) {
Platform.putInt(null, offset, value);
}
@@ -274,24 +275,24 @@ public void putInts(int rowId, int count, int value) {
@Override
public void putInts(int rowId, int count, int[] src, int srcIndex) {
- Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4,
- null, data + 4 * rowId, count * 4);
+ Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4L,
+ null, data + 4L * rowId, count * 4L);
}
@Override
public void putInts(int rowId, int count, byte[] src, int srcIndex) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
- null, data + rowId * 4, count * 4);
+ null, data + rowId * 4L, count * 4L);
}
@Override
public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
if (!bigEndianPlatform) {
Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET,
- null, data + 4 * rowId, count * 4);
+ null, data + 4L * rowId, count * 4L);
} else {
int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET;
- long offset = data + 4 * rowId;
+ long offset = data + 4L * rowId;
for (int i = 0; i < count; ++i, offset += 4, srcOffset += 4) {
Platform.putInt(null, offset,
java.lang.Integer.reverseBytes(Platform.getInt(src, srcOffset)));
@@ -302,7 +303,7 @@ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex)
@Override
public int getInt(int rowId) {
if (dictionary == null) {
- return Platform.getInt(null, data + 4 * rowId);
+ return Platform.getInt(null, data + 4L * rowId);
} else {
return dictionary.decodeToInt(dictionaryIds.getDictId(rowId));
}
@@ -312,7 +313,7 @@ public int getInt(int rowId) {
public int[] getInts(int rowId, int count) {
assert(dictionary == null);
int[] array = new int[count];
- Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4);
+ Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L);
return array;
}
@@ -324,7 +325,7 @@ public int[] getInts(int rowId, int count) {
public int getDictId(int rowId) {
assert(dictionary == null)
: "A ColumnVector dictionary should not have a dictionary for itself.";
- return Platform.getInt(null, data + 4 * rowId);
+ return Platform.getInt(null, data + 4L * rowId);
}
//
@@ -333,12 +334,12 @@ public int getDictId(int rowId) {
@Override
public void putLong(int rowId, long value) {
- Platform.putLong(null, data + 8 * rowId, value);
+ Platform.putLong(null, data + 8L * rowId, value);
}
@Override
public void putLongs(int rowId, int count, long value) {
- long offset = data + 8 * rowId;
+ long offset = data + 8L * rowId;
for (int i = 0; i < count; ++i, offset += 8) {
Platform.putLong(null, offset, value);
}
@@ -346,24 +347,24 @@ public void putLongs(int rowId, int count, long value) {
@Override
public void putLongs(int rowId, int count, long[] src, int srcIndex) {
- Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8,
- null, data + 8 * rowId, count * 8);
+ Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8L,
+ null, data + 8L * rowId, count * 8L);
}
@Override
public void putLongs(int rowId, int count, byte[] src, int srcIndex) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
- null, data + rowId * 8, count * 8);
+ null, data + rowId * 8L, count * 8L);
}
@Override
public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
if (!bigEndianPlatform) {
Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET,
- null, data + 8 * rowId, count * 8);
+ null, data + 8L * rowId, count * 8L);
} else {
int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET;
- long offset = data + 8 * rowId;
+ long offset = data + 8L * rowId;
for (int i = 0; i < count; ++i, offset += 8, srcOffset += 8) {
Platform.putLong(null, offset,
java.lang.Long.reverseBytes(Platform.getLong(src, srcOffset)));
@@ -374,7 +375,7 @@ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex)
@Override
public long getLong(int rowId) {
if (dictionary == null) {
- return Platform.getLong(null, data + 8 * rowId);
+ return Platform.getLong(null, data + 8L * rowId);
} else {
return dictionary.decodeToLong(dictionaryIds.getDictId(rowId));
}
@@ -384,7 +385,7 @@ public long getLong(int rowId) {
public long[] getLongs(int rowId, int count) {
assert(dictionary == null);
long[] array = new long[count];
- Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8);
+ Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L);
return array;
}
@@ -394,12 +395,12 @@ public long[] getLongs(int rowId, int count) {
@Override
public void putFloat(int rowId, float value) {
- Platform.putFloat(null, data + rowId * 4, value);
+ Platform.putFloat(null, data + rowId * 4L, value);
}
@Override
public void putFloats(int rowId, int count, float value) {
- long offset = data + 4 * rowId;
+ long offset = data + 4L * rowId;
for (int i = 0; i < count; ++i, offset += 4) {
Platform.putFloat(null, offset, value);
}
@@ -407,18 +408,18 @@ public void putFloats(int rowId, int count, float value) {
@Override
public void putFloats(int rowId, int count, float[] src, int srcIndex) {
- Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4,
- null, data + 4 * rowId, count * 4);
+ Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4L,
+ null, data + 4L * rowId, count * 4L);
}
@Override
public void putFloats(int rowId, int count, byte[] src, int srcIndex) {
if (!bigEndianPlatform) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
- null, data + rowId * 4, count * 4);
+ null, data + rowId * 4L, count * 4L);
} else {
ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN);
- long offset = data + 4 * rowId;
+ long offset = data + 4L * rowId;
for (int i = 0; i < count; ++i, offset += 4) {
Platform.putFloat(null, offset, bb.getFloat(srcIndex + (4 * i)));
}
@@ -428,7 +429,7 @@ public void putFloats(int rowId, int count, byte[] src, int srcIndex) {
@Override
public float getFloat(int rowId) {
if (dictionary == null) {
- return Platform.getFloat(null, data + rowId * 4);
+ return Platform.getFloat(null, data + rowId * 4L);
} else {
return dictionary.decodeToFloat(dictionaryIds.getDictId(rowId));
}
@@ -438,7 +439,7 @@ public float getFloat(int rowId) {
public float[] getFloats(int rowId, int count) {
assert(dictionary == null);
float[] array = new float[count];
- Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4);
+ Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L);
return array;
}
@@ -449,12 +450,12 @@ public float[] getFloats(int rowId, int count) {
@Override
public void putDouble(int rowId, double value) {
- Platform.putDouble(null, data + rowId * 8, value);
+ Platform.putDouble(null, data + rowId * 8L, value);
}
@Override
public void putDoubles(int rowId, int count, double value) {
- long offset = data + 8 * rowId;
+ long offset = data + 8L * rowId;
for (int i = 0; i < count; ++i, offset += 8) {
Platform.putDouble(null, offset, value);
}
@@ -462,18 +463,18 @@ public void putDoubles(int rowId, int count, double value) {
@Override
public void putDoubles(int rowId, int count, double[] src, int srcIndex) {
- Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8,
- null, data + 8 * rowId, count * 8);
+ Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8L,
+ null, data + 8L * rowId, count * 8L);
}
@Override
public void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
if (!bigEndianPlatform) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
- null, data + rowId * 8, count * 8);
+ null, data + rowId * 8L, count * 8L);
} else {
ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN);
- long offset = data + 8 * rowId;
+ long offset = data + 8L * rowId;
for (int i = 0; i < count; ++i, offset += 8) {
Platform.putDouble(null, offset, bb.getDouble(srcIndex + (8 * i)));
}
@@ -483,7 +484,7 @@ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
@Override
public double getDouble(int rowId) {
if (dictionary == null) {
- return Platform.getDouble(null, data + rowId * 8);
+ return Platform.getDouble(null, data + rowId * 8L);
} else {
return dictionary.decodeToDouble(dictionaryIds.getDictId(rowId));
}
@@ -493,7 +494,7 @@ public double getDouble(int rowId) {
public double[] getDoubles(int rowId, int count) {
assert(dictionary == null);
double[] array = new double[count];
- Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8);
+ Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8L);
return array;
}
@@ -503,26 +504,26 @@ public double[] getDoubles(int rowId, int count) {
@Override
public void putArray(int rowId, int offset, int length) {
assert(offset >= 0 && offset + length <= childColumns[0].capacity);
- Platform.putInt(null, lengthData + 4 * rowId, length);
- Platform.putInt(null, offsetData + 4 * rowId, offset);
+ Platform.putInt(null, lengthData + 4L * rowId, length);
+ Platform.putInt(null, offsetData + 4L * rowId, offset);
}
@Override
public int getArrayLength(int rowId) {
- return Platform.getInt(null, lengthData + 4 * rowId);
+ return Platform.getInt(null, lengthData + 4L * rowId);
}
@Override
public int getArrayOffset(int rowId) {
- return Platform.getInt(null, offsetData + 4 * rowId);
+ return Platform.getInt(null, offsetData + 4L * rowId);
}
// APIs dealing with ByteArrays
@Override
public int putByteArray(int rowId, byte[] value, int offset, int length) {
int result = arrayData().appendBytes(length, value, offset);
- Platform.putInt(null, lengthData + 4 * rowId, length);
- Platform.putInt(null, offsetData + 4 * rowId, result);
+ Platform.putInt(null, lengthData + 4L * rowId, length);
+ Platform.putInt(null, offsetData + 4L * rowId, result);
return result;
}
@@ -532,19 +533,19 @@ protected void reserveInternal(int newCapacity) {
int oldCapacity = (nulls == 0L) ? 0 : capacity;
if (isArray() || type instanceof MapType) {
this.lengthData =
- Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4);
+ Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L);
this.offsetData =
- Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4);
+ Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L);
} else if (type instanceof ByteType || type instanceof BooleanType) {
this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity);
} else if (type instanceof ShortType) {
- this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2);
+ this.data = Platform.reallocateMemory(data, oldCapacity * 2L, newCapacity * 2L);
} else if (type instanceof IntegerType || type instanceof FloatType ||
type instanceof DateType || DecimalType.is32BitDecimalType(type)) {
- this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4);
+ this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L);
} else if (type instanceof LongType || type instanceof DoubleType ||
DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) {
- this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8);
+ this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L);
} else if (childColumns != null) {
// Nothing to store.
} else {
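
The widening from rowId * 4 to rowId * 4L (and the 2L/8L variants) above matters because the multiplication is otherwise performed in 32-bit int arithmetic and can wrap around before being added to the long base address. A tiny demonstration of the difference:

    public class OffsetOverflowSketch {
      public static void main(String[] args) {
        int rowId = 600_000_000;          // plausible row index for a very large off-heap column
        long intProduct = rowId * 4;      // int multiplication wraps before widening to long
        long longProduct = rowId * 4L;    // promoted to long before multiplying

        System.out.println(intProduct);   // -1894967296
        System.out.println(longProduct);  // 2400000000
      }
    }
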
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 23dcc104e67c4..577eab6ed14c8 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -231,7 +231,7 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) {
@Override
public void putShorts(int rowId, int count, byte[] src, int srcIndex) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData,
- Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2);
+ Platform.SHORT_ARRAY_OFFSET + rowId * 2L, count * 2L);
}
@Override
@@ -276,7 +276,7 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) {
@Override
public void putInts(int rowId, int count, byte[] src, int srcIndex) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData,
- Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4);
+ Platform.INT_ARRAY_OFFSET + rowId * 4L, count * 4L);
}
@Override
@@ -342,7 +342,7 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) {
@Override
public void putLongs(int rowId, int count, byte[] src, int srcIndex) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData,
- Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8);
+ Platform.LONG_ARRAY_OFFSET + rowId * 8L, count * 8L);
}
@Override
@@ -394,7 +394,7 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) {
public void putFloats(int rowId, int count, byte[] src, int srcIndex) {
if (!bigEndianPlatform) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, floatData,
- Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4);
+ Platform.DOUBLE_ARRAY_OFFSET + rowId * 4L, count * 4L);
} else {
ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN);
for (int i = 0; i < count; ++i) {
@@ -443,7 +443,7 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) {
public void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
if (!bigEndianPlatform) {
Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData,
- Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8);
+ Platform.DOUBLE_ARRAY_OFFSET + rowId * 8L, count * 8L);
} else {
ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN);
for (int i = 0; i < count; ++i) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 5275e4a91eac0..b0e119d658cb4 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -81,7 +81,9 @@ public void close() {
}
public void reserve(int requiredCapacity) {
- if (requiredCapacity > capacity) {
+ if (requiredCapacity < 0) {
+ throwUnsupportedException(requiredCapacity, null);
+ } else if (requiredCapacity > capacity) {
int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L);
if (requiredCapacity <= newCapacity) {
try {
@@ -96,13 +98,16 @@ public void reserve(int requiredCapacity) {
}
private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
- String message = "Cannot reserve additional contiguous bytes in the vectorized reader " +
- "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " +
- "vectorized reader, or increase the vectorized reader batch size. For parquet file " +
- "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " +
- SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " +
- SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " +
- SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + ".";
+ String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" +
+ (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") +
+ "). As a workaround, you can reduce the vectorized reader batch size, or disable the " +
+ "vectorized reader. For parquet file format, refer to " +
+ SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() +
+ " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() +
+ ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " +
+ "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() +
+ " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() +
+ ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + ".";
throw new RuntimeException(message, cause);
}
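
The new requiredCapacity < 0 guard above catches callers whose own size computation has already overflowed int, while the existing doubling is done in long (requiredCapacity * 2L) and clamped so it cannot wrap. A minimal sketch of that growth rule; the MAX_CAPACITY value and method shape here are illustrative, not the class's actual fields:

    public class ReserveGrowthSketch {
      // Illustrative cap; the real constant lives in WritableColumnVector.
      static final int MAX_CAPACITY = Integer.MAX_VALUE - 15;

      // Returns the new capacity, or -1 if the request cannot be satisfied.
      static int grow(int requiredCapacity, int currentCapacity) {
        if (requiredCapacity < 0) {
          return -1;  // caller's size math already overflowed; fail instead of allocating
        }
        if (requiredCapacity <= currentCapacity) {
          return currentCapacity;
        }
        // Double in long arithmetic and clamp, so the intermediate value cannot wrap.
        int newCapacity = (int) Math.min((long) MAX_CAPACITY, requiredCapacity * 2L);
        if (requiredCapacity > newCapacity) {
          return -1;  // even the clamped capacity cannot hold the request
        }
        return newCapacity;
      }

      public static void main(String[] args) {
        System.out.println(grow(1_000, 0));          // 2000
        System.out.println(grow(1_500_000_000, 0));  // clamped to MAX_CAPACITY
        System.out.println(grow(-8, 0));             // -1: overflowed request
      }
    }
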
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java
similarity index 92%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java
index 0c1d5d1a9577a..7df5a451ae5f3 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java
@@ -15,13 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.sql.sources.v2.reader;
+package org.apache.spark.sql.sources.v2;
import java.util.Optional;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.DataSourceV2;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader;
import org.apache.spark.sql.types.StructType;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
index c32053580f016..83df3be747085 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
@@ -17,16 +17,61 @@
package org.apache.spark.sql.sources.v2;
+import java.io.IOException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
+import java.util.stream.Stream;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.spark.annotation.InterfaceStability;
/**
* An immutable string-to-string map in which keys are case-insensitive. This is used to represent
* data source options.
+ *
+ * Each data source implementation can define its own options and teach its users how to set them.
+ * Spark doesn't have any restrictions about what options a data source should or should not have.
+ * Instead Spark defines some standard options that data sources can optionally adopt. It's possible
+ * that some options are very common and many data sources use them. However, different data
+ * sources may define the common options (key and meaning) differently, which is quite confusing to
+ * end users.
+ *
+ * The standard options defined by Spark:
+ * <table summary="standard data source options">
+ *   <tr>
+ *     <th>Option key</th>
+ *     <th>Option value</th>
+ *   </tr>
+ *   <tr>
+ *     <td>path</td>
+ *     <td>A path string of the data files/directories, like
+ *     path1, /absolute/file2, path3/*. The path can
+ *     either be relative or absolute, points to either file or directory, and can contain
+ *     wildcards. This option is commonly used by file-based data sources.</td>
+ *   </tr>
+ *   <tr>
+ *     <td>paths</td>
+ *     <td>A JSON array style paths string of the data files/directories, like
+ *     ["path1", "/absolute/file2"]. The format of each path is same as the
+ *     path option, plus it should follow JSON string literal format, e.g. quotes
+ *     should be escaped, pa\"th means pa"th.</td>
+ *   </tr>
+ *   <tr>
+ *     <td>table</td>
+ *     <td>A table name string representing the table name directly without any interpretation.
+ *     For example, db.tbl means a table called db.tbl, not a table called tbl
+ *     inside database db. `t*b.l` means a table called `t*b.l`, not t*b.l.</td>
+ *   </tr>
+ *   <tr>
+ *     <td>database</td>
+ *     <td>A database name string representing the database name directly without any
+ *     interpretation, which is very similar to the table name option.</td>
+ *   </tr>
+ * </table>
*/
@InterfaceStability.Evolving
public class DataSourceOptions {
@@ -97,4 +142,59 @@ public double getDouble(String key, double defaultValue) {
return keyLowerCasedMap.containsKey(lcaseKey) ?
Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
}
+
+ /**
+ * The option key for singular path.
+ */
+ public static final String PATH_KEY = "path";
+
+ /**
+ * The option key for multiple paths.
+ */
+ public static final String PATHS_KEY = "paths";
+
+ /**
+ * The option key for table name.
+ */
+ public static final String TABLE_KEY = "table";
+
+ /**
+ * The option key for database name.
+ */
+ public static final String DATABASE_KEY = "database";
+
+ /**
+ * Returns all the paths specified by both the singular path option and the multiple
+ * paths option.
+ */
+ public String[] paths() {
+ String[] singularPath =
+ get(PATH_KEY).map(s -> new String[]{s}).orElseGet(() -> new String[0]);
+ Optional<String> pathsStr = get(PATHS_KEY);
+ if (pathsStr.isPresent()) {
+ ObjectMapper objectMapper = new ObjectMapper();
+ try {
+ String[] paths = objectMapper.readValue(pathsStr.get(), String[].class);
+ return Stream.of(singularPath, paths).flatMap(Stream::of).toArray(String[]::new);
+ } catch (IOException e) {
+ return singularPath;
+ }
+ } else {
+ return singularPath;
+ }
+ }
+
+ /**
+ * Returns the value of the table name option.
+ */
+ public Optional<String> tableName() {
+ return get(TABLE_KEY);
+ }
+
+ /**
+ * Returns the value of the database name option.
+ */
+ public Optional<String> databaseName() {
+ return get(DATABASE_KEY);
+ }
}
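
The new paths() helper above merges the singular path option with the JSON-array-valued paths option. A hedged usage sketch, assuming the Map-based DataSourceOptions constructor this class already exposes; the option values are made up:

    import java.util.HashMap;
    import java.util.Map;

    import org.apache.spark.sql.sources.v2.DataSourceOptions;

    public class DataSourceOptionsPathsExample {
      public static void main(String[] args) {
        Map<String, String> raw = new HashMap<>();
        raw.put("path", "data/part1");                          // singular path
        raw.put("paths", "[\"data/part2\", \"/abs/part3\"]");   // JSON array of paths
        raw.put("table", "events");

        DataSourceOptions options = new DataSourceOptions(raw);

        // paths() concatenates the singular path with every entry from the JSON array.
        for (String p : options.paths()) {
          System.out.println(p);   // data/part1, data/part2, /abs/part3
        }
        System.out.println(options.tableName().orElse("<none>"));  // events
      }
    }
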
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java
similarity index 89%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java
index 5e8f0c0dafdcf..7f4a2c9593c76 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java
@@ -15,13 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.sql.sources.v2.reader;
+package org.apache.spark.sql.sources.v2;
import java.util.Optional;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
-import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader;
import org.apache.spark.sql.types.StructType;
@@ -36,7 +34,7 @@ public interface MicroBatchReadSupport extends DataSourceV2 {
* streaming query.
*
* The execution engine will create a micro-batch reader at the start of a streaming query,
- * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and
+ * alternate calls to setOffsetRange and planInputPartitions for each batch to process, and
* then call stop() when the execution is complete. Note that a single query may have multiple
* executions due to restart or failure recovery.
*
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
index 0ea4dc6b5def3..b2526ded53d92 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
@@ -30,7 +30,7 @@ public interface ReadSupport extends DataSourceV2 {
/**
* Creates a {@link DataSourceReader} to scan the data from this data source.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param options the options for the returned data source reader, which is an immutable
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
index 3801402268af1..f31659904cc53 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
@@ -35,7 +35,7 @@ public interface ReadSupportWithSchema extends DataSourceV2 {
/**
* Create a {@link DataSourceReader} to scan the data from this data source.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param schema the full schema of this data source reader. Full schema usually maps to the
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java
similarity index 93%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java
index 1c0e2e12f8d51..a77b01497269e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java
@@ -15,12 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.sql.sources.v2.writer;
+package org.apache.spark.sql.sources.v2;
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.execution.streaming.BaseStreamingSink;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
-import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter;
import org.apache.spark.sql.streaming.OutputMode;
import org.apache.spark.sql.types.StructType;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java
index cab56453816cc..83aeec0c47853 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java
@@ -35,7 +35,7 @@ public interface WriteSupport extends DataSourceV2 {
* Creates an optional {@link DataSourceWriter} to save the data to this data source. Data
* sources can return None if there is no writing needed to be done according to the save mode.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param jobId A unique string for the writing job. It's possible that there are many writing
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
similarity index 53%
rename from common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
index 74ebc87dc978c..dcb87715d0b6f 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
@@ -15,40 +15,21 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.memory;
+package org.apache.spark.sql.sources.v2.reader;
-import javax.annotation.Nullable;
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;
/**
- * A memory location. Tracked either by a memory address (with off-heap allocation),
- * or by an offset from a JVM object (in-heap allocation).
+ * A mix-in interface for {@link InputPartition}. Continuous input partitions can implement
+ * this interface to allow creating an {@link InputPartitionReader} with a particular offset.
*/
-public class MemoryLocation {
-
- @Nullable
- Object obj;
-
- long offset;
-
- public MemoryLocation(@Nullable Object obj, long offset) {
- this.obj = obj;
- this.offset = offset;
- }
-
- public MemoryLocation() {
- this(null, 0);
- }
-
- public void setObjAndOffset(Object newObj, long newOffset) {
- this.obj = newObj;
- this.offset = newOffset;
- }
-
- public final Object getBaseObject() {
- return obj;
- }
-
- public final long getBaseOffset() {
- return offset;
- }
+@InterfaceStability.Evolving
+public interface ContinuousInputPartition<T> extends InputPartition<T> {
+ /**
+ * Create an input partition reader with a particular offset as its startOffset.
+ *
+ * @param offset the offset to set as the input partition reader's startOffset.
+ */
+ InputPartitionReader<T> createContinuousReader(PartitionOffset offset);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
index a470bccc5aad2..36a3e542b5a11 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
@@ -31,8 +31,8 @@
* {@link ReadSupport#createReader(DataSourceOptions)} or
* {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}.
* It can mix in various query optimization interfaces to speed up the data scan. The actual scan
- * logic is delegated to {@link DataReaderFactory}s that are returned by
- * {@link #createDataReaderFactories()}.
+ * logic is delegated to {@link InputPartition}s, which are returned by
+ * {@link #planInputPartitions()}.
*
* There are mainly 3 kinds of query optimizations:
* 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column
@@ -45,8 +45,8 @@
* only one of them would be respected, according to the priority list from high to low:
* {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
*
- * If an exception was throw when applying any of these query optimizations, the action would fail
- * and no Spark job was submitted.
+ * If an exception was thrown when applying any of these query optimizations, the action will fail
+ * and no Spark job will be submitted.
*
* Spark first applies all operator push-down optimizations that this data source supports. Then
* Spark collects information this data source reported for further optimizations. Finally Spark
@@ -59,22 +59,22 @@ public interface DataSourceReader {
* Returns the actual schema of this data source reader, which may be different from the physical
* schema of the underlying storage, as column pruning or other optimizations may happen.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
StructType readSchema();
/**
- * Returns a list of reader factories. Each factory is responsible for creating a data reader to
- * output data for one RDD partition. That means the number of factories returned here is same as
- * the number of RDD partitions this scan outputs.
+ * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for
+ * creating a data reader to output data of one RDD partition. The number of input partitions
+ * returned here is the same as the number of RDD partitions this scan outputs.
*
* Note that, this may not be a full scan if the data source reader mixes in other optimization
* interfaces like column pruning, filter push-down, etc. These optimizations are applied before
* Spark issues the scan request.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
- List<DataReaderFactory<Row>> createDataReaderFactories();
+ List<InputPartition<Row>> planInputPartitions();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
similarity index 60%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
index 32e98e8f5d8bd..f2038d0de3ffe 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
@@ -22,29 +22,30 @@
import org.apache.spark.annotation.InterfaceStability;
/**
- * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is
- * responsible for creating the actual data reader. The relationship between
- * {@link DataReaderFactory} and {@link DataReader}
+ * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is
+ * responsible for creating the actual data reader of one RDD partition.
+ * The relationship between {@link InputPartition} and {@link InputPartitionReader}
* is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}.
*
- * Note that, the reader factory will be serialized and sent to executors, then the data reader
- * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be
- * serializable and {@link DataReader} doesn't need to be.
+ * Note that {@link InputPartition}s will be serialized and sent to executors, then
+ * {@link InputPartitionReader}s will be created on executors to do the actual reading. So
+ * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to
+ * be.
*/
@InterfaceStability.Evolving
-public interface DataReaderFactory extends Serializable {
+public interface InputPartition extends Serializable {
/**
- * The preferred locations where the data reader returned by this reader factory can run faster,
- * but Spark does not guarantee to run the data reader on these locations.
+ * The preferred locations where the input partition reader returned by this partition can run
+ * faster, but Spark does not guarantee to run the input partition reader on these locations.
* The implementations should make sure that it can be run on any location.
* The location is a string representing the host name.
*
* Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in
- * the returned locations. By default this method returns empty string array, which means this
- * task has no location preference.
+ * the returned locations. The default return value is an empty string array, which means this
+ * input partition's reader has no location preference.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
default String[] preferredLocations() {
@@ -52,10 +53,10 @@ default String[] preferredLocations() {
}
/**
- * Returns a data reader to do the actual reading work.
+ * Returns an input partition reader to do the actual reading work.
*
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
- DataReader createDataReader();
+ InputPartitionReader createPartitionReader();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
similarity index 80%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
index bb9790a1c819e..33fa7be4c1b20 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
@@ -23,15 +23,15 @@
import org.apache.spark.annotation.InterfaceStability;
/**
- * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for
- * outputting data for a RDD partition.
+ * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is
+ * responsible for outputting data for a RDD partition.
*
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
- * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source
- * readers that mix in {@link SupportsScanUnsafeRow}.
+ * Note that currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input
+ * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input
+ * partition readers that mix in {@link SupportsScanUnsafeRow}.
*/
@InterfaceStability.Evolving
-public interface DataReader extends Closeable {
+public interface InputPartitionReader extends Closeable {
/**
* Proceed to next record, returns false if there is no more records.
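For orientation, here is a minimal sketch of the renamed read path in Scala. The class names and fixed data are hypothetical, and `Row` is used as the element type following the note above about normal (non-UnsafeRow, non-columnar) readers; this illustrates the contract rather than reproducing code from this patch.

{{{
import java.util

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical reader that scans a fixed integer range split into two partitions.
class RangeReader extends DataSourceReader {
  override def readSchema(): StructType = StructType(StructField("i", IntegerType) :: Nil)

  // One InputPartition per RDD partition of the scan.
  override def planInputPartitions(): util.List[InputPartition[Row]] =
    util.Arrays.asList[InputPartition[Row]](new RangePartition(0, 5), new RangePartition(5, 10))
}

// Serializable description of one partition; shipped to executors.
class RangePartition(start: Int, end: Int) extends InputPartition[Row] {
  // The reader is created on the executor and does not need to be serializable.
  override def createPartitionReader(): InputPartitionReader[Row] = new InputPartitionReader[Row] {
    private var current = start - 1
    override def next(): Boolean = { current += 1; current < end }
    override def get(): Row = Row(current)
    override def close(): Unit = ()
  }
}
}}}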
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java
index 98224102374aa..4543c143a9aca 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java
@@ -34,12 +34,21 @@
public interface SupportsPushDownCatalystFilters extends DataSourceReader {
/**
- * Pushes down filters, and returns unsupported filters.
+ * Pushes down filters, and returns filters that need to be evaluated after scanning.
*/
Expression[] pushCatalystFilters(Expression[] filters);
/**
- * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}.
+ * Returns the catalyst filters that are pushed to the data source via
+ * {@link #pushCatalystFilters(Expression[])}.
+ *
+ * There are 3 kinds of filters:
+ * 1. pushable filters which don't need to be evaluated again after scanning.
+ * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet
+ * row group filter.
+ * 3. non-pushable filters.
+ * Both case 1 and 2 should be considered as pushed filters and should be returned by this method.
+ *
* It's possible that there is no filters in the query and
* {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for
* this case.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
index f35c711b0387a..b6a90a3d0b681 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
@@ -32,12 +32,20 @@
public interface SupportsPushDownFilters extends DataSourceReader {
/**
- * Pushes down filters, and returns unsupported filters.
+ * Pushes down filters, and returns filters that need to be evaluated after scanning.
*/
Filter[] pushFilters(Filter[] filters);
/**
- * Returns the filters that are pushed in {@link #pushFilters(Filter[])}.
+ * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}.
+ *
+ * There are 3 kinds of filters:
+ * 1. pushable filters which don't need to be evaluated again after scanning.
+ * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet
+ * row group filter.
+ * 3. non-pushable filters.
+ * Both case 1 and 2 should be considered as pushed filters and should be returned by this method.
+ *
* It's possible that there is no filters in the query and {@link #pushFilters(Filter[])}
* is never called, empty array should be returned for this case.
*/
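As a hedged illustration of which filters a source keeps and which it hands back, a reader that can only evaluate equality predicates itself might split them roughly as follows (the class name and the EqualTo-only rule are made up; readSchema()/planInputPartitions() are left abstract because they are unrelated to push-down):

{{{
import org.apache.spark.sql.sources.{EqualTo, Filter}
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters

abstract class EqualityPushDownReader extends SupportsPushDownFilters {
  private var pushed: Array[Filter] = Array.empty

  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    val (supported, unsupported) = filters.partition(_.isInstanceOf[EqualTo])
    pushed = supported   // cases 1 and 2: applied by the source during the scan
    unsupported          // case 3: Spark must still evaluate these after scanning
  }

  // Everything returned here is treated by Spark as pushed to the source.
  override def pushedFilters(): Array[Filter] = pushed
}
}}}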
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
index 5405a916951b8..6b60da7c4dc1d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
@@ -23,6 +23,9 @@
/**
* A mix in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to report data partitioning and try to avoid shuffle at Spark side.
+ *
+ * Note that, when the reader creates exactly one {@link InputPartition}, Spark may avoid
+ * adding a shuffle even if the reader does not implement this interface.
*/
@InterfaceStability.Evolving
public interface SupportsReportPartitioning extends DataSourceReader {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
index 11bb13fd3b211..926396414816c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
@@ -22,6 +22,10 @@
/**
* A mix in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to report statistics to Spark.
+ *
+ * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader.
+ * Implementations that return more accurate statistics based on pushed operators will not improve
+ * query performance until the planner can push operators before getting stats.
*/
@InterfaceStability.Evolving
public interface SupportsReportStatistics extends DataSourceReader {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
index 2e5cfa78511f0..0faf81db24605 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
@@ -30,22 +30,22 @@
@InterfaceStability.Evolving
public interface SupportsScanColumnarBatch extends DataSourceReader {
@Override
- default List> createDataReaderFactories() {
+ default List> planInputPartitions() {
throw new IllegalStateException(
- "createDataReaderFactories not supported by default within SupportsScanColumnarBatch.");
+ "planInputPartitions not supported by default within SupportsScanColumnarBatch.");
}
/**
- * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data
+ * Similar to {@link DataSourceReader#planInputPartitions()}, but returns columnar data
* in batches.
*/
- List> createBatchDataReaderFactories();
+ List> planBatchInputPartitions();
/**
* Returns true if the concrete data source reader can read data in batch according to the scan
* properties like required columns, pushes filters, etc. It's possible that the implementation
* can only support some certain columns with certain types. Users can overwrite this method and
- * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions.
+ * {@link #planInputPartitions()} to fall back to the normal read path under some conditions.
*/
default boolean enableBatchRead() {
return true;
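A sketch of the fallback pattern described above, assuming `ColumnarBatch` as the batch element type; the decision rule in enableBatchRead() is purely illustrative and the two planning methods are left as placeholders:

{{{
import java.util

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{InputPartition, SupportsScanColumnarBatch}
import org.apache.spark.sql.types.{DoubleType, IntegerType}
import org.apache.spark.sql.vectorized.ColumnarBatch

abstract class MaybeColumnarReader extends SupportsScanColumnarBatch {

  // Columnar path, used when enableBatchRead() returns true.
  override def planBatchInputPartitions(): util.List[InputPartition[ColumnarBatch]] = ???

  // Row-based fallback: must also be overridden, since the default implementation throws.
  override def planInputPartitions(): util.List[InputPartition[Row]] = ???

  // Hypothetical rule: only go columnar when every column is a numeric type we handle.
  override def enableBatchRead(): Boolean =
    readSchema().fields.forall(f => f.dataType == IntegerType || f.dataType == DoubleType)
}
}}}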
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
index 9cd749e8e4ce9..f2220f6d31093 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
@@ -33,14 +33,14 @@
public interface SupportsScanUnsafeRow extends DataSourceReader {
@Override
- default List> createDataReaderFactories() {
+ default List> planInputPartitions() {
throw new IllegalStateException(
- "createDataReaderFactories not supported by default within SupportsScanUnsafeRow");
+ "planInputPartitions not supported by default within SupportsScanUnsafeRow");
}
/**
- * Similar to {@link DataSourceReader#createDataReaderFactories()},
+ * Similar to {@link DataSourceReader#planInputPartitions()},
* but returns data in unsafe row format.
*/
- List> createUnsafeRowReaderFactories();
+ List> planUnsafeInputPartitions();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
index 2d0ee50212b56..38ca5fc6387b2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
@@ -18,12 +18,12 @@
package org.apache.spark.sql.sources.v2.reader.partitioning;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataReader;
+import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
/**
* A concrete implementation of {@link Distribution}. Represents a distribution where records that
* share the same values for the {@link #clusteredColumns} will be produced by the same
- * {@link DataReader}.
+ * {@link InputPartitionReader}.
*/
@InterfaceStability.Evolving
public class ClusteredDistribution implements Distribution {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
index f6b111fdf220d..5e32ba6952e1c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
@@ -18,13 +18,14 @@
package org.apache.spark.sql.sources.v2.reader.partitioning;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataReader;
+import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
/**
* An interface to represent data distribution requirement, which specifies how the records should
- * be distributed among the data partitions(one {@link DataReader} outputs data for one partition).
+ * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one
+ * partition).
* Note that this interface has nothing to do with the data ordering inside one
- * partition(the output records of a single {@link DataReader}).
+ * partition (the output records of a single {@link InputPartitionReader}).
*
* The instance of this interface is created and provided by Spark, then consumed by
* {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
index 309d9e5de0a0f..f460f6bfe3bb9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
@@ -18,7 +18,7 @@
package org.apache.spark.sql.sources.v2.reader.partitioning;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning;
/**
@@ -31,7 +31,7 @@
public interface Partitioning {
/**
- * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs.
+ * Returns the number of partitions (i.e., {@link InputPartition}s) the data source outputs.
*/
int numPartitions();
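For example, a reader whose storage already groups rows by a single key column could report that layout via SupportsReportPartitioning by returning a Partitioning along these lines (the column name and partition count are hypothetical):

{{{
import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}

// The source guarantees that all records sharing the same "key" value land in the
// same input partition, so a ClusteredDistribution over "key" is satisfied.
class KeyClusteredPartitioning(parts: Int) extends Partitioning {
  override def numPartitions(): Int = parts

  override def satisfy(distribution: Distribution): Boolean = distribution match {
    case c: ClusteredDistribution => c.clusteredColumns.toSeq == Seq("key")
    case _ => false
  }
}
}}}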
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java
similarity index 84%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java
index 47d26440841fd..7b0ba0bbdda90 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java
@@ -18,13 +18,13 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.sources.v2.reader.DataReader;
+import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
/**
- * A variation on {@link DataReader} for use with streaming in continuous processing mode.
+ * A variation on {@link InputPartitionReader} for use with streaming in continuous processing mode.
*/
@InterfaceStability.Evolving
-public interface ContinuousDataReader extends DataReader {
+public interface ContinuousInputPartitionReader extends InputPartitionReader {
/**
* Get the offset of the current record, or the start offset if no records have been read.
*
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
index 7fe7f00ac2fa8..6e960bedf8020 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
@@ -27,7 +27,7 @@
* A mix-in interface for {@link DataSourceReader}. Data source readers can implement this
* interface to allow reading in a continuous processing mode stream.
*
- * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}.
+ * Implementations must ensure each partition reader is a {@link ContinuousInputPartitionReader}.
*
* Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with
* DataSource V1 APIs. This extension will be removed once we get rid of V1 completely.
@@ -35,8 +35,8 @@
@InterfaceStability.Evolving
public interface ContinuousReader extends BaseStreamingSource, DataSourceReader {
/**
- * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each
- * partition to a single global offset.
+ * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances
+ * for each partition to a single global offset.
*/
Offset mergeOffsets(PartitionOffset[] offsets);
@@ -47,7 +47,7 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader
Offset deserializeOffset(String json);
/**
- * Set the desired start offset for reader factories created from this reader. The scan will
+ * Set the desired start offset for partitions created from this reader. The scan will
* start from the first record after the provided offset, or from an implementation-defined
* inferred starting point if no offset is provided.
*/
@@ -61,8 +61,8 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader
Offset getStartOffset();
/**
- * The execution engine will call this method in every epoch to determine if new reader
- * factories need to be generated, which may be required if for example the underlying
+ * The execution engine will call this method in every epoch to determine if new input
+ * partitions need to be generated, which may be required if for example the underlying
* source system has had partitions added or removed.
*
* If true, the query will be shut down and restarted with a new reader.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java
index 67ebde30d61a9..0159c731762d9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java
@@ -33,7 +33,7 @@
@InterfaceStability.Evolving
public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource {
/**
- * Set the desired offset range for reader factories created from this reader. Reader factories
+ * Set the desired offset range for input partitions created from this reader. Partition readers
* will generate only data within (`start`, `end`]; that is, from the first record after `start`
* to the record with offset `end`.
*
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
index 52324b3792b8a..0030a9f05dba7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
@@ -21,6 +21,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
+import org.apache.spark.sql.sources.v2.StreamWriteSupport;
import org.apache.spark.sql.sources.v2.WriteSupport;
import org.apache.spark.sql.streaming.OutputMode;
import org.apache.spark.sql.types.StructType;
@@ -33,8 +34,8 @@
* It can mix in various writing optimization interfaces to speed up the data saving. The actual
* writing logic is delegated to {@link DataWriter}.
*
- * If an exception was throw when applying any of these writing optimizations, the action would fail
- * and no Spark job was submitted.
+ * If an exception is thrown when applying any of these writing optimizations, the action will fail
+ * and no Spark job will be submitted.
*
* The writing procedure is:
* 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the
@@ -57,11 +58,21 @@ public interface DataSourceWriter {
/**
* Creates a writer factory which will be serialized and sent to executors.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
DataWriterFactory createWriterFactory();
+ /**
+ * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for
+ * each task commits.
+ *
+ * @return true if commit coordinator should be used, false otherwise.
+ */
+ default boolean useCommitCoordinator() {
+ return true;
+ }
+
/**
* Handles a commit message on receiving from a successful data writer.
*
@@ -78,10 +89,11 @@ default void onDataWriterCommit(WriterCommitMessage message) {}
* failed, and {@link #abort(WriterCommitMessage[])} would be called. The state of the destination
* is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it.
*
- * Note that, one partition may have multiple committed data writers because of speculative tasks.
- * Spark will pick the first successful one and get its commit message. Implementations should be
- * aware of this and handle it correctly, e.g., have a coordinator to make sure only one data
- * writer can commit, or have a way to clean up the data of already-committed writers.
+ * Note that speculative execution may cause multiple tasks to run for a partition. By default,
+ * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can
+ * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple
+ * attempts may have committed successfully and one successful commit message per task will be
+ * passed to this commit method. The remaining commit messages are ignored by Spark.
*/
void commit(WriterCommitMessage[] messages);
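A hedged sketch of the coordinator-related part of this contract: a writer that opts out of the commit coordinator and reconciles duplicate attempts itself might look roughly like this (the commit message type and the println placeholders are illustrative; createWriterFactory is left abstract):

{{{
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, WriterCommitMessage}

// Hypothetical commit message identifying the output of one task attempt.
case class FileCommitted(partitionId: Int, path: String) extends WriterCommitMessage

abstract class NoCoordinatorWriter extends DataSourceWriter {

  // Allow every attempt of a task to commit; we reconcile the results below.
  override def useCommitCoordinator(): Boolean = false

  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    // Spark passes one message per task; output written by other committed attempts
    // of the same task is not represented here and should be cleaned up by the sink.
    messages.collect { case m: FileCommitted => m }
      .foreach(m => println(s"finalizing ${m.path} for partition ${m.partitionId}"))
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    // Best effort: some slots may be null if a writer never committed.
    messages.collect { case m: FileCommitted => m }.foreach(m => println(s"deleting ${m.path}"))
  }
}
}}}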
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
index 53941a89ba94e..39bf458298862 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
@@ -22,7 +22,7 @@
import org.apache.spark.annotation.InterfaceStability;
/**
- * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int)} and is
+ * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is
* responsible for writing data for an input RDD partition.
*
* One Spark task has one exclusive data writer, so there is no thread-safe concern.
@@ -31,13 +31,17 @@
* the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will
* not be processed. If all records are successfully written, {@link #commit()} is called.
*
+ * Once a data writer returns successfully from {@link #commit()} or {@link #abort()}, its lifecycle
+ * is over and Spark will not use it again.
+ *
* If this data writer succeeds(all records are successfully written and {@link #commit()}
* succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to
* {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data
* writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an
- * exception will be sent to the driver side, and Spark will retry this writing task for some times,
- * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`,
- * and finally call {@link DataSourceWriter#abort(WriterCommitMessage[])} if all retry fail.
+ * exception will be sent to the driver side, and Spark may retry this writing task a few times.
+ * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a
+ * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])}
+ * when the configured number of retries is exhausted.
*
* Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task
* takes too long to finish. Different from retried tasks, which are launched one by one after the
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
index ea95442511ce5..7527bcc0c4027 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
@@ -35,7 +35,7 @@ public interface DataWriterFactory extends Serializable {
/**
* Returns a data writer to do the actual writing work.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param partitionId A unique id of the RDD partition that the returned writer will process.
@@ -48,6 +48,9 @@ public interface DataWriterFactory extends Serializable {
* same task id but different attempt number, which means there are multiple
* tasks with the same task id running at the same time. Implementations can
* use this attempt number to distinguish writers of different task attempts.
+ * @param epochId A monotonically increasing id for streaming queries that are split into
+ * discrete periods of execution. For non-streaming queries,
+ * this ID will always be 0.
*/
- DataWriter createDataWriter(int partitionId, int attemptNumber);
+ DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId);
}
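To make the new epochId parameter concrete, a factory and writer pair might look like the following sketch (the commit message type and the way epochId is folded into the output name are assumptions; a real writer would flush to external storage in commit()):

{{{
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}

// Hypothetical message naming the output produced by one task attempt.
case class PartFile(name: String) extends WriterCommitMessage

class PartFileWriterFactory extends DataWriterFactory[Row] {
  // partitionId, attemptNumber and epochId together identify one attempt of one task;
  // for batch queries epochId is always 0.
  override def createDataWriter(partitionId: Int, attemptNumber: Int, epochId: Long): DataWriter[Row] =
    new DataWriter[Row] {
      private val name = s"part-$epochId-$partitionId-$attemptNumber"
      private val buffer = scala.collection.mutable.ArrayBuffer.empty[Row]

      override def write(record: Row): Unit = buffer += record
      override def commit(): WriterCommitMessage = PartFile(name) // would flush `buffer` here
      override def abort(): Unit = buffer.clear()
    }
}
}}}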
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java
index 4913341bd505d..a316b2a4c1d82 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java
@@ -23,11 +23,10 @@
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
/**
- * A {@link DataSourceWriter} for use with structured streaming. This writer handles commits and
- * aborts relative to an epoch ID determined by the execution engine.
+ * A {@link DataSourceWriter} for use with structured streaming.
*
- * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs,
- * and so must reset any internal state after a successful commit.
+ * Streaming queries are divided into intervals of data called epochs, with a monotonically
+ * increasing numeric ID. This writer handles commits and aborts for each successive epoch.
*/
@InterfaceStability.Evolving
public interface StreamWriter extends DataSourceWriter {
@@ -39,21 +38,21 @@ public interface StreamWriter extends DataSourceWriter {
* If this method fails (by throwing an exception), this writing job is considered to have been
* failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}.
*
- * To support exactly-once processing, writer implementations should ensure that this method is
- * idempotent. The execution engine may call commit() multiple times for the same epoch
- * in some circumstances.
+ * The execution engine may call commit() multiple times for the same epoch in some circumstances.
+ * To support exactly-once data semantics, implementations must ensure that multiple commits for
+ * the same epoch are idempotent.
*/
void commit(long epochId, WriterCommitMessage[] messages);
/**
- * Aborts this writing job because some data writers are failed and keep failing when retry, or
+ * Aborts this writing job because some data writers failed and keep failing when retried, or
* the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails.
*
* If this method fails (by throwing an exception), the underlying data source may require manual
* cleanup.
*
- * Unless the abort is triggered by the failure of commit, the given messages should have some
- * null slots as there maybe only a few data writers that are committed before the abort
+ * Unless the abort is triggered by the failure of commit, the given messages will have some
+ * null slots, as there may be only a few data writers that were committed before the abort
* happens, or some data writers were committed but their commit messages haven't reached the
* driver when the abort is triggered. So this is just a "best effort" for data sources to
* clean up the data left by data writers.
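The idempotency requirement above can be met, for instance, by remembering the last epoch that was committed and turning repeated commits of that epoch into no-ops. A rough sketch follows; keeping the watermark in driver memory is an assumption made for brevity, real sinks usually persist it in the external system:

{{{
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter

abstract class IdempotentStreamWriter extends StreamWriter {
  @volatile private var lastCommittedEpoch: Long = -1L

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    if (epochId <= lastCommittedEpoch) return  // already committed: make the retry a no-op
    // ... atomically publish this epoch's data here ...
    lastCommittedEpoch = epochId
  }

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // messages may contain null slots; a real sink would delete only the output
    // described by the non-null ones.
  }
}
}}}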
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index f8e37e995a17f..227a16f7e69e9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -25,6 +25,7 @@
import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.execution.arrow.ArrowUtils;
import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.memory.OffHeapMemoryBlock;
import org.apache.spark.unsafe.types.UTF8String;
/**
@@ -377,9 +378,10 @@ final UTF8String getUTF8String(int rowId) {
if (stringResult.isSet == 0) {
return null;
} else {
- return UTF8String.fromAddress(null,
+ return new UTF8String(new OffHeapMemoryBlock(
stringResult.buffer.memoryAddress() + stringResult.start,
- stringResult.end - stringResult.start);
+ stringResult.end - stringResult.start
+ ));
}
}
}
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 0259c774bbf4a..1b37905543b4e 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
org.apache.spark.sql.execution.datasources.text.TextFileFormat
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
-org.apache.spark.sql.execution.streaming.TextSocketSourceProvider
-org.apache.spark.sql.execution.streaming.RateSourceProvider
-org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2
+org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
+org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 92988680871a4..4eee3de5f7d4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import scala.collection.JavaConverters._
import scala.language.implicitConversions
import org.apache.spark.annotation.InterfaceStability
@@ -103,7 +104,7 @@ class TypedColumn[-T, U](
*
* {{{
* df("columnName") // On a specific `df` DataFrame.
- * col("columnName") // A generic column no yet associated with a DataFrame.
+ * col("columnName") // A generic column not yet associated with a DataFrame.
* col("columnName.field") // Extracting a struct field
* col("`a.column.with.dots`") // Escape `.` in column names.
* $"columnName" // Scala short hand for a named column.
@@ -780,12 +781,54 @@ class Column(val expr: Expression) extends Logging {
* A boolean expression that is evaluated to true if the value of this expression is contained
* by the evaluated values of the arguments.
*
+ * Note: Since the type of the elements in the list is inferred only at run time,
+ * the elements will be "up-casted" to the most common type for comparison.
+ * For example:
+ * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the
+ * comparison will look like "String vs String".
+ * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the
+ * comparison will look like "Double vs Double"
+ *
* @group expr_ops
* @since 1.5.0
*/
@scala.annotation.varargs
def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) }
+ /**
+ * A boolean expression that is evaluated to true if the value of this expression is contained
+ * by the provided collection.
+ *
+ * Note: Since the type of the elements in the collection is inferred only at run time,
+ * the elements will be "up-casted" to the most common type for comparison.
+ * For example:
+ * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the
+ * comparison will look like "String vs String".
+ * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the
+ * comparison will look like "Double vs Double"
+ *
+ * @group expr_ops
+ * @since 2.4.0
+ */
+ def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*)
+
+ /**
+ * A boolean expression that is evaluated to true if the value of this expression is contained
+ * by the provided collection.
+ *
+ * Note: Since the type of the elements in the collection is inferred only at run time,
+ * the elements will be "up-casted" to the most common type for comparison.
+ * For example:
+ * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the
+ * comparison will look like "String vs String".
+ * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the
+ * comparison will look like "Double vs Double"
+ *
+ * @group java_expr_ops
+ * @since 2.4.0
+ */
+ def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala)
+
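A quick usage sketch of the new method (assuming a DataFrame `df` with an integer column "id"; elements are up-cast as described in the note above):

{{{
val ids = Seq(1, 2, 3)
df.filter(df("id").isInCollection(ids)).show()
}}}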
/**
* SQL like expression. Returns a boolean column based on a SQL LIKE match.
*
@@ -1083,10 +1126,10 @@ class Column(val expr: Expression) extends Logging {
* and null values return before non-null values.
* {{{
* // Scala: sort a DataFrame by age column in ascending order and null values appearing first.
- * df.sort(df("age").asc_nulls_last)
+ * df.sort(df("age").asc_nulls_first)
*
* // Java
- * df.sort(df.col("age").asc_nulls_last());
+ * df.sort(df.col("age").asc_nulls_first());
* }}}
*
* @group expr_ops
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index fcaf8d618c168..ec9352a7fa055 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -21,6 +21,9 @@ import java.util.{Locale, Properties}
import scala.collection.JavaConverters._
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.univocity.parsers.csv.CsvParser
+
import org.apache.spark.Partition
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.api.java.JavaRDD
@@ -34,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
-import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
@@ -171,7 +174,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @since 1.4.0
*/
def load(path: String): DataFrame = {
- option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)`
+ // force invocation of `load(...varargs...)`
+ option(DataSourceOptions.PATH_KEY, path).load(Seq.empty: _*)
}
/**
@@ -189,39 +193,19 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
- val ds = cls.newInstance()
- val options = new DataSourceOptions((extraOptions ++
- DataSourceV2Utils.extractSessionConfigs(
- ds = ds.asInstanceOf[DataSourceV2],
- conf = sparkSession.sessionState.conf)).asJava)
-
- // Streaming also uses the data source V2 API. So it may be that the data source implements
- // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading
- // the dataframe as a v1 source.
- val reader = (ds, userSpecifiedSchema) match {
- case (ds: ReadSupportWithSchema, Some(schema)) =>
- ds.createReader(schema, options)
-
- case (ds: ReadSupport, None) =>
- ds.createReader(options)
-
- case (ds: ReadSupportWithSchema, None) =>
- throw new AnalysisException(s"A schema needs to be specified when using $ds.")
-
- case (ds: ReadSupport, Some(schema)) =>
- val reader = ds.createReader(options)
- if (reader.readSchema() != schema) {
- throw new AnalysisException(s"$ds does not allow user-specified schemas.")
- }
- reader
-
- case _ => null // fall back to v1
- }
-
- if (reader == null) {
- loadV1Source(paths: _*)
+ val ds = cls.newInstance().asInstanceOf[DataSourceV2]
+ if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) {
+ val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
+ ds = ds, conf = sparkSession.sessionState.conf)
+ val pathsOption = {
+ val objectMapper = new ObjectMapper()
+ DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray)
+ }
+ Dataset.ofRows(sparkSession, DataSourceV2Relation.create(
+ ds, extraOptions.toMap ++ sessionOptions + pathsOption,
+ userSpecifiedSchema = userSpecifiedSchema))
} else {
- Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
+ loadV1Source(paths: _*)
}
} else {
loadV1Source(paths: _*)
@@ -274,7 +258,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
* should be included. "fetchsize" can be used to control the
- * number of rows per fetch.
+ * number of rows per fetch and "queryTimeout" can be used to set how many
+ * seconds to wait for a Statement object to execute.
* @since 1.4.0
*/
def jdbc(
@@ -368,12 +353,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*
`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.
*
- *
`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
- * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
- * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
- * in an user-defined schema. If a schema does not have the field, it drops corrupt records
- * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord`
- * field in an output schema.
+ *
`PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a
+ * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To
+ * keep corrupt records, an user can set a string type field named
+ * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the
+ * field, it drops corrupt records during parsing. When inferring a schema, it implicitly
+ * adds a `columnNameOfCorruptRecord` field in an output schema.
*
`DROPMALFORMED` : ignores the whole corrupted records.
*
`FAILFAST` : throws an exception when it meets corrupted records.
*
@@ -389,6 +374,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `java.text.SimpleDateFormat`. This applies to timestamp type.
*
`multiLine` (default `false`): parse one record, which may span multiple lines,
* per file
+ *
`encoding` (by default it is not set): allows to forcibly set one of standard basic
+ * or extended encoding for the JSON files. For example UTF-16BE, UTF-32LE. If the encoding
+ * is not specified and `multiLine` is set to `true`, it will be detected automatically.
+ *
`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
+ *
`samplingRatio` (default is 1.0): defines fraction of input JSON objects used
+ * for schema inferring.
+ *
`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
+ * empty array/struct during schema inference.
*
*
* @since 2.0.0
@@ -483,12 +477,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* it determines the columns as string types and it reads only the first line to determine the
* names and the number of fields.
*
+ * If `enforceSchema` is set to `false`, only the CSV header in the first line is checked
+ * to conform to the specified or inferred schema.
+ *
* @param csvDataset input Dataset with one CSV row per record
* @since 2.2.0
*/
def csv(csvDataset: Dataset[String]): DataFrame = {
val parsedOptions: CSVOptions = new CSVOptions(
extraOptions.toMap,
+ sparkSession.sessionState.conf.csvColumnPruning,
sparkSession.sessionState.conf.sessionLocalTimeZone)
val filteredLines: Dataset[String] =
CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions)
@@ -507,6 +505,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+ CSVDataSource.checkHeader(
+ firstLine,
+ new CsvParser(parsedOptions.asParserSettings),
+ actualSchema,
+ csvDataset.getClass.getCanonicalName,
+ parsedOptions.enforceSchema,
+ sparkSession.sessionState.conf.caseSensitiveAnalysis)
filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
}.getOrElse(filteredLines.rdd)
@@ -547,8 +552,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*
`comment` (default empty string): sets a single character used for skipping lines
* beginning with this character. By default, it is disabled.
*
`header` (default `false`): uses the first line as names of columns.
+ *
`enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema
+ * will be forcibly applied to datasource files, and headers in CSV files will be ignored.
+ * If the option is set to `false`, the schema will be validated against all headers in CSV files
+ * in the case when the `header` option is set to `true`. Field names in the schema
+ * and column names in CSV headers are checked by their positions taking into account
+ * `spark.sql.caseSensitive`. Though the default value is true, it is recommended to disable
+ * the `enforceSchema` option to avoid incorrect results.
*
`inferSchema` (default `false`): infers the input schema automatically from data. It
* requires one extra pass over the data.
+ *
`samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.
*
`ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading
* whitespaces from values being read should be skipped.
*
`ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing
@@ -573,12 +586,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*
`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing. It supports the following case-insensitive modes.
*
- *
`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
- * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+ *
`PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a
+ * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep
* corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
* in an user-defined schema. If a schema does not have the field, it drops corrupt records
- * during parsing. When a length of parsed CSV tokens is shorter than an expected length
- * of a schema, it sets `null` for extra fields.
+ * during parsing. A record with less/more tokens than schema is not a corrupted record to
+ * CSV. When it meets a record having fewer tokens than the length of the schema, sets
+ * `null` to extra fields. When the record has more tokens than the length of the schema,
+ * it drops extra tokens.
*
`DROPMALFORMED` : ignores the whole corrupted records.
*
`FAILFAST` : throws an exception when it meets corrupted records.
*
@@ -588,6 +603,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
*
`multiLine` (default `false`): parse one record, which may span multiple lines.
*
+ *
* @since 2.0.0
*/
@scala.annotation.varargs
@@ -668,14 +684,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* Loads text files and returns a `DataFrame` whose schema starts with a string column named
* "value", and followed by partitioned columns if there are any.
*
- * You can set the following text-specific option(s) for reading text files:
- *
- *
`wholetext` ( default `false`): If true, read a file as a single row and not split by "\n".
- *
- *
- * By default, each line in the text files is a new row in the resulting DataFrame.
- *
- * Usage example:
+ * By default, each line in the text files is a new row in the resulting DataFrame. For example:
* {{{
* // Scala:
* spark.read.text("/path/to/spark/README.md")
@@ -684,6 +693,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* spark.read().text("/path/to/spark/README.md")
* }}}
*
+ * You can set the following text-specific option(s) for reading text files:
+ *
+ *
`wholetext` (default `false`): If true, read a file as a single row and not split by "\n".
+ *
+ *
`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
+ *
+ *
* @param paths input paths
* @since 1.6.0
*/
@@ -707,11 +724,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* If the directory structure of the text files contains partitioning information, those are
* ignored in the resulting Dataset. To include partitioning information as columns, use `text`.
*
- * You can set the following textFile-specific option(s) for reading text files:
- *
- *
`wholetext` ( default `false`): If true, read a file as a single row and not split by "\n".
- *
- *
* By default, each line in the text files is a new row in the resulting DataFrame. For example:
* {{{
* // Scala:
@@ -721,6 +733,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* spark.read().textFile("/path/to/spark/README.md")
* }}}
*
+ * You can set the following textFile-specific option(s) for reading text files:
+ *
+ *
`wholetext` (default `false`): If true, read a file as a single row and not split by "\n".
+ *
+ *
`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
+ *
+ *
* @param paths input path
* @since 2.0.0
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index ed7a9100cc7f1..90bea2d676e22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -330,8 +330,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
private def getBucketSpec: Option[BucketSpec] = {
- if (sortColumnNames.isDefined) {
- require(numBuckets.isDefined, "sortBy must be used together with bucketBy")
+ if (sortColumnNames.isDefined && numBuckets.isEmpty) {
+ throw new AnalysisException("sortBy must be used together with bucketBy")
}
numBuckets.map { n =>
@@ -340,8 +340,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
private def assertNotBucketed(operation: String): Unit = {
- if (numBuckets.isDefined || sortColumnNames.isDefined) {
- throw new AnalysisException(s"'$operation' does not support bucketing right now")
+ if (getBucketSpec.isDefined) {
+ if (sortColumnNames.isEmpty) {
+ throw new AnalysisException(s"'$operation' does not support bucketBy right now")
+ } else {
+ throw new AnalysisException(s"'$operation' does not support bucketBy and sortBy right now")
+ }
}
}
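In user-facing terms, the tightened checks mean sortBy is only legal together with bucketBy, and bucketed output still has to go through saveAsTable; for instance (table, path and column names are illustrative):

{{{
// OK: sortBy accompanies bucketBy and the output is a table.
df.write.bucketBy(8, "user_id").sortBy("event_time").saveAsTable("events_bucketed")

// AnalysisException: sortBy must be used together with bucketBy
df.write.sortBy("event_time").saveAsTable("events_sorted")

// AnalysisException: 'save' does not support bucketBy and sortBy right now
df.write.bucketBy(8, "user_id").sortBy("event_time").parquet("/tmp/events")
}}}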
@@ -518,6 +522,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
*
`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that
* indicates a timestamp format. Custom date formats follow the formats at
* `java.text.SimpleDateFormat`. This applies to timestamp type.
+ *
`encoding` (by default it is not set): specifies encoding (charset) of saved json
+ * files. If it is not set, the UTF-8 charset will be used.
+ *
`lineSep` (default `\n`): defines the line separator that should be used for writing.
*
*
* @since 1.4.0
@@ -587,6 +594,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
*
`compression` (default `null`): compression codec to use when saving to file. This can be
* one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
* `snappy` and `deflate`).
+ *
`lineSep` (default `\n`): defines the line separator that should be used for writing.
*
*
* @since 1.6.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 0aee1d7be5788..f5526104690d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -196,7 +196,7 @@ class Dataset[T] private[sql](
}
// Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again.
- @transient private val planWithBarrier = AnalysisBarrier(logicalPlan)
+ @transient private[sql] val planWithBarrier = AnalysisBarrier(logicalPlan)
/**
* Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the
@@ -231,16 +231,17 @@ class Dataset[T] private[sql](
}
/**
- * Compose the string representing rows for output
+ * Get rows as a Seq of Seq of String, formatted according to the truncate and vertical settings.
*
- * @param _numRows Number of rows to show
+ * @param numRows Number of rows to return
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
- * @param vertical If set to true, prints output rows vertically (one line per column value).
+ * @param vertical If set to true, the rows to return do not need to be truncated.
*/
- private[sql] def showString(
- _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
- val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+ private[sql] def getRows(
+ numRows: Int,
+ truncate: Int,
+ vertical: Boolean): Seq[Seq[String]] = {
val newDf = toDF()
val castCols = newDf.logicalPlan.output.map { col =>
// Since binary types in top-level schema fields have a specific format to print,
@@ -251,14 +252,12 @@ class Dataset[T] private[sql](
Column(col).cast(StringType)
}
}
- val takeResult = newDf.select(castCols: _*).take(numRows + 1)
- val hasMoreData = takeResult.length > numRows
- val data = takeResult.take(numRows)
+ val data = newDf.select(castCols: _*).take(numRows + 1)
// For array values, replace Seq and Array with square brackets
// For cells that are beyond `truncate` characters, replace it with the
// first `truncate-3` and "..."
- val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+ schema.fieldNames.toSeq +: data.map { row =>
row.toSeq.map { cell =>
val str = cell match {
case null => "null"
@@ -274,6 +273,26 @@ class Dataset[T] private[sql](
}
}: Seq[String]
}
+ }
+
+ /**
+ * Compose the string representing rows for output
+ *
+ * @param _numRows Number of rows to show
+ * @param truncate If set to more than 0, truncates strings to `truncate` characters and
+ * all cells will be aligned right.
+ * @param vertical If set to true, prints output rows vertically (one line per column value).
+ */
+ private[sql] def showString(
+ _numRows: Int,
+ truncate: Int = 20,
+ vertical: Boolean = false): String = {
+ val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+ // Get rows represented by Seq[Seq[String]]; we may get one more row if there is more data.
+ val tmpRows = getRows(numRows, truncate, vertical)
+
+ val hasMoreData = tmpRows.length - 1 > numRows
+ val rows = tmpRows.take(numRows + 1)
val sb = new StringBuilder
val numCols = schema.fieldNames.length
@@ -291,31 +310,25 @@ class Dataset[T] private[sql](
}
}
+ val paddedRows = rows.map { row =>
+ row.zipWithIndex.map { case (cell, i) =>
+ if (truncate > 0) {
+ StringUtils.leftPad(cell, colWidths(i))
+ } else {
+ StringUtils.rightPad(cell, colWidths(i))
+ }
+ }
+ }
+
// Create SeparateLine
val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
// column names
- rows.head.zipWithIndex.map { case (cell, i) =>
- if (truncate > 0) {
- StringUtils.leftPad(cell, colWidths(i))
- } else {
- StringUtils.rightPad(cell, colWidths(i))
- }
- }.addString(sb, "|", "|", "|\n")
-
+ paddedRows.head.addString(sb, "|", "|", "|\n")
sb.append(sep)
// data
- rows.tail.foreach {
- _.zipWithIndex.map { case (cell, i) =>
- if (truncate > 0) {
- StringUtils.leftPad(cell.toString, colWidths(i))
- } else {
- StringUtils.rightPad(cell.toString, colWidths(i))
- }
- }.addString(sb, "|", "|", "|\n")
- }
-
+ paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n"))
sb.append(sep)
} else {
// Extended display mode enabled
@@ -346,7 +359,7 @@ class Dataset[T] private[sql](
}
// Print a footer
- if (vertical && data.isEmpty) {
+ if (vertical && rows.tail.isEmpty) {
// In a vertical mode, print an empty row set explicitly
sb.append("(0 rows)\n")
} else if (hasMoreData) {
@@ -511,6 +524,16 @@ class Dataset[T] private[sql](
*/
def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation]
+ /**
+ * Returns true if the `Dataset` is empty.
+ *
+ * @group basic
+ * @since 2.4.0
+ */
+ def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan =>
+ plan.executeCollect().head.getLong(0) == 0
+ }
+
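Usage is straightforward; because the implementation above counts over limit(1), it avoids collecting data and stays cheap even for large inputs (the range Datasets below are just for illustration):

{{{
assert(spark.range(0).isEmpty)
assert(!spark.range(10).isEmpty)
}}}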
/**
* Returns true if this Dataset contains one or more sources that continuously
* return data as it arrives. A Dataset that reads data from a streaming source
@@ -1607,7 +1630,9 @@ class Dataset[T] private[sql](
*/
@Experimental
@InterfaceStability.Evolving
- def reduce(func: (T, T) => T): T = rdd.reduce(func)
+ def reduce(func: (T, T) => T): T = withNewRDDExecutionId {
+ rdd.reduce(func)
+ }
/**
* :: Experimental ::
@@ -2933,7 +2958,7 @@ class Dataset[T] private[sql](
*/
def storageLevel: StorageLevel = {
sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData =>
- cachedData.cachedRepresentation.storageLevel
+ cachedData.cachedRepresentation.cacheBuilder.storageLevel
}.getOrElse(StorageLevel.NONE)
}
@@ -3187,27 +3212,41 @@ class Dataset[T] private[sql](
EvaluatePython.javaToPython(rdd)
}
- private[sql] def collectToPython(): Int = {
+ private[sql] def collectToPython(): Array[Any] = {
EvaluatePython.registerPicklers()
- withNewExecutionId {
+ withAction("collectToPython", queryExecution) { plan =>
val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
- val iter = new SerDeUtil.AutoBatchedPickler(
- queryExecution.executedPlan.executeCollect().iterator.map(toJava))
+ val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
+ plan.executeCollect().iterator.map(toJava))
PythonRDD.serveIterator(iter, "serve-DataFrame")
}
}
+ private[sql] def getRowsToPython(
+ _numRows: Int,
+ truncate: Int,
+ vertical: Boolean): Array[Any] = {
+ EvaluatePython.registerPicklers()
+ val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+ val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray
+ val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)))
+ val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
+ rows.iterator.map(toJava))
+ PythonRDD.serveIterator(iter, "serve-GetRows")
+ }
+
/**
* Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
*/
- private[sql] def collectAsArrowToPython(): Int = {
- withNewExecutionId {
- val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable)
+ private[sql] def collectAsArrowToPython(): Array[Any] = {
+ withAction("collectAsArrowToPython", queryExecution) { plan =>
+ val iter: Iterator[Array[Byte]] =
+ toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
PythonRDD.serveIterator(iter, "serve-Arrow")
}
}
- private[sql] def toPythonIterator(): Int = {
+ private[sql] def toPythonIterator(): Array[Any] = {
withNewExecutionId {
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
}
@@ -3311,14 +3350,19 @@ class Dataset[T] private[sql](
}
/** Convert to an RDD of ArrowPayload byte arrays */
- private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+ private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
- queryExecution.toRdd.mapPartitionsInternal { iter =>
+ plan.execute().mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toPayloadIterator(
iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context)
}
}
+
+ // This is only used in tests, for now.
+ private[sql] def toArrowPayload: RDD[ArrowPayload] = {
+ toArrowPayload(queryExecution.executedPlan)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
index 86e02e98c01f3..b21c50af18433 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
@@ -20,10 +20,48 @@ package org.apache.spark.sql
import org.apache.spark.annotation.InterfaceStability
/**
- * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the
- * generated data to external systems. Each partition will use a new deserialized instance, so you
- * usually should do all the initialization (e.g. opening a connection or initiating a transaction)
- * in the `open` method.
+ * The abstract class for writing custom logic to process data generated by a query.
+ * This is often used to write the output of a streaming query to arbitrary storage systems.
+ * Any implementation of this base class will be used by Spark in the following way.
+ *
+ * <ul>
+ * <li>A single instance of this class is responsible for all the data generated by a single task
+ * in a query. In other words, one instance is responsible for processing one partition of the
+ * data generated in a distributed manner.
+ *
+ * <li>Any implementation of this class must be serializable because each task will get a fresh
+ * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that
+ * any initialization for writing data (e.g. opening a connection or starting a transaction)
+ * is done after the `open(...)` method has been called, which signifies that the task is
+ * ready to generate data.
+ *
+ * <li>The lifecycle of the methods is as follows.
+ *
+ * <pre>
+ *   For each partition with `partitionId`:
+ *       For each batch/epoch of streaming data (if it is a streaming query) with `epochId`:
+ *           Method `open(partitionId, epochId)` is called.
+ *           If `open` returns true:
+ *               For each row in the partition and batch/epoch, method `process(row)` is called.
+ *           Method `close(errorOrNull)` is called with error (if any) seen while processing rows.
+ * </pre>
+ *
+ * </ul>
+ *
+ * Important points to note:
+ *
+ * <ul>
+ * <li>The `partitionId` and `epochId` can be used to deduplicate generated data when failures
+ * cause reprocessing of some input data. This depends on the execution mode of the query. If
+ * the streaming query is being executed in the micro-batch mode, then every partition
+ * represented by a unique tuple (partitionId, epochId) is guaranteed to have the same data.
+ * Hence, (partitionId, epochId) can be used to deduplicate and/or transactionally commit data
+ * and achieve exactly-once guarantees. However, if the streaming query is being executed in the
+ * continuous mode, then this guarantee does not hold and therefore should not be used for
+ * deduplication.
+ *
+ * <li>The `close()` method will be called if the `open()` method returns successfully (irrespective
+ * of the return value), except if the JVM crashes in the middle.
+ * </ul>
*
* Scala example:
* {{{
@@ -63,6 +101,7 @@ import org.apache.spark.annotation.InterfaceStability
* }
* });
* }}}
+ *
* @since 2.0.0
*/
@InterfaceStability.Evolving
@@ -71,23 +110,18 @@ abstract class ForeachWriter[T] extends Serializable {
// TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API.
/**
- * Called when starting to process one partition of new data in the executor. The `version` is
- * for data deduplication when there are failures. When recovering from a failure, some data may
- * be generated multiple times but they will always have the same version.
- *
- * If this method finds using the `partitionId` and `version` that this partition has already been
- * processed, it can return `false` to skip the further data processing. However, `close` still
- * will be called for cleaning up resources.
+ * Called when starting to process one partition of new data in the executor. See the class
+ * docs for more information on how to use the `partitionId` and `epochId`.
*
* @param partitionId the partition id.
- * @param version a unique id for data deduplication.
+ * @param epochId a unique id for data deduplication.
* @return `true` if the corresponding partition and version id should be processed. `false`
* indicates the partition should be skipped.
*/
- def open(partitionId: Long, version: Long): Boolean
+ def open(partitionId: Long, epochId: Long): Boolean
/**
- * Called to process the data in the executor side. This method will be called only when `open`
+ * Called to process the data on the executor side. This method will be called only if `open`
* returns `true`.
*/
def process(value: T): Unit
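A minimal `ForeachWriter` sketch (illustrative, not part of this patch) showing how the renamed `epochId` can drive deduplication in micro-batch mode; `alreadyCommitted` is a hypothetical tracker:

```scala
import org.apache.spark.sql.ForeachWriter

class IdempotentWriter extends ForeachWriter[String] {
  // Hypothetical lookup against an external store recording committed (partitionId, epochId) pairs.
  private def alreadyCommitted(partitionId: Long, epochId: Long): Boolean = false

  override def open(partitionId: Long, epochId: Long): Boolean =
    !alreadyCommitted(partitionId, epochId)   // returning false skips process() for replayed work

  override def process(value: String): Unit = {
    // write the row to the external sink
  }

  override def close(errorOrNull: Throwable): Unit = {
    // commit on success, roll back or log when errorOrNull != null
  }
}
```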
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 6bab21dca0cbd..36f6038aa9485 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -49,7 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
private implicit val kExprEnc = encoderFor(kEncoder)
private implicit val vExprEnc = encoderFor(vEncoder)
- private def logicalPlan = queryExecution.analyzed
+ private def logicalPlan = AnalysisBarrier(queryExecution.analyzed)
private def sparkSession = queryExecution.sparkSession
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 7147798d99533..c6449cd5a16b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -63,17 +63,17 @@ class RelationalGroupedDataset protected[sql](
groupType match {
case RelationalGroupedDataset.GroupByType =>
Dataset.ofRows(
- df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
+ df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.planWithBarrier))
case RelationalGroupedDataset.RollupType =>
Dataset.ofRows(
- df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
+ df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.planWithBarrier))
case RelationalGroupedDataset.CubeType =>
Dataset.ofRows(
- df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
+ df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.planWithBarrier))
case RelationalGroupedDataset.PivotType(pivotCol, values) =>
val aliasedGrps = groupingExprs.map(alias)
Dataset.ofRows(
- df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
+ df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.planWithBarrier))
}
}
@@ -433,7 +433,7 @@ class RelationalGroupedDataset protected[sql](
df.exprEnc.schema,
groupingAttributes,
df.logicalPlan.output,
- df.logicalPlan))
+ df.planWithBarrier))
}
/**
@@ -459,7 +459,7 @@ class RelationalGroupedDataset protected[sql](
case other => Alias(other, other.toString)()
}
val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
- val child = df.logicalPlan
+ val child = df.planWithBarrier
val project = Project(groupingNamedExpressions ++ child.output, child)
val output = expr.dataType.asInstanceOf[StructType].toAttributes
val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 734573ba31f71..565042fcf762e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
-import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
+import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
@@ -44,7 +44,7 @@ import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ExecutionListenerManager
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{CallSite, Utils}
/**
@@ -81,6 +81,9 @@ class SparkSession private(
@transient private[sql] val extensions: SparkSessionExtensions)
extends Serializable with Closeable with Logging { self =>
+ // The call site where this SparkSession was constructed.
+ private val creationSite: CallSite = Utils.getCallSite()
+
private[sql] def this(sc: SparkContext) {
this(sc, None, None, new SparkSessionExtensions)
}
@@ -763,7 +766,7 @@ class SparkSession private(
@InterfaceStability.Stable
-object SparkSession {
+object SparkSession extends Logging {
/**
* Builder for [[SparkSession]].
@@ -895,6 +898,7 @@ object SparkSession {
* @since 2.0.0
*/
def getOrCreate(): SparkSession = synchronized {
+ assertOnDriver()
// Get the session from current thread's active session.
var session = activeThreadSession.get()
if ((session ne null) && !session.sparkContext.isStopped) {
@@ -951,7 +955,8 @@ object SparkSession {
session = new SparkSession(sparkContext, None, None, extensions)
options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
- defaultSession.set(session)
+ setDefaultSession(session)
+ setActiveSession(session)
// Register a successfully instantiated context to the singleton. This should be at the
// end of the class definition so that the singleton is updated only if there is no
@@ -1016,16 +1021,45 @@ object SparkSession {
/**
* Returns the active SparkSession for the current thread, returned by the builder.
*
+ * @note Returns None when called on executors.
+ *
* @since 2.2.0
*/
- def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get)
+ def getActiveSession: Option[SparkSession] = {
+ if (TaskContext.get != null) {
+ // Return None when running on executors.
+ None
+ } else {
+ Option(activeThreadSession.get)
+ }
+ }
/**
* Returns the default SparkSession that is returned by the builder.
*
+ * @note Returns None when called on executors.
+ *
* @since 2.2.0
*/
- def getDefaultSession: Option[SparkSession] = Option(defaultSession.get)
+ def getDefaultSession: Option[SparkSession] = {
+ if (TaskContext.get != null) {
+ // Return None when running on executors.
+ None
+ } else {
+ Option(defaultSession.get)
+ }
+ }
+
+ /**
+ * Returns the currently active SparkSession, otherwise the default one. If there is no default
+ * SparkSession, throws an exception.
+ *
+ * @since 2.4.0
+ */
+ def active: SparkSession = {
+ getActiveSession.getOrElse(getDefaultSession.getOrElse(
+ throw new IllegalStateException("No active or default Spark session found")))
+ }
////////////////////////////////////////////////////////////////////////////////////////
// Private methods from now on
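A short usage sketch of the new `SparkSession.active` accessor (illustrative; assumes driver-side code):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
// Library code with no session handle can later do:
val current = SparkSession.active   // active thread session, else default, else IllegalStateException
current.range(3).count()
```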
@@ -1047,6 +1081,14 @@ object SparkSession {
}
}
+ private def assertOnDriver(): Unit = {
+ if (Utils.isTesting && TaskContext.get != null) {
+ // we're accessing it during task execution, fail.
+ throw new IllegalStateException(
+ "SparkSession should only be created and accessed on the driver.")
+ }
+ }
+
/**
* Helper method to create an instance of `SessionState` based on `className` from conf.
* The result is either `SessionState` or a Hive based `SessionState`.
@@ -1078,4 +1120,20 @@ object SparkSession {
}
}
+ private[spark] def cleanupAnyExistingSession(): Unit = {
+ val session = getActiveSession.orElse(getDefaultSession)
+ if (session.isDefined) {
+ logWarning(
+ s"""An existing Spark session exists as the active or default session.
+ |This probably means another suite leaked it. Attempting to stop it before continuing.
+ |This existing Spark session was created at:
+ |
+ |${session.get.creationSite.longForm}
+ |
+ """.stripMargin)
+ session.get.stop()
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index d68aeb275afda..93bf91e56f1bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -71,7 +71,7 @@ class CacheManager extends Logging {
/** Clears all cached tables. */
def clearCache(): Unit = writeLock {
- cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+ cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache())
cachedData.clear()
}
@@ -99,7 +99,7 @@ class CacheManager extends Logging {
sparkSession.sessionState.conf.columnBatchSize, storageLevel,
sparkSession.sessionState.executePlan(planToCache).executedPlan,
tableName,
- planToCache.stats)
+ planToCache)
cachedData.add(CachedData(planToCache, inMemoryRelation))
}
}
@@ -119,7 +119,7 @@ class CacheManager extends Logging {
while (it.hasNext) {
val cd = it.next()
if (cd.plan.find(_.sameResult(plan)).isDefined) {
- cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+ cd.cachedRepresentation.cacheBuilder.clearCache(blocking)
it.remove()
}
}
@@ -138,17 +138,15 @@ class CacheManager extends Logging {
while (it.hasNext) {
val cd = it.next()
if (condition(cd.plan)) {
- cd.cachedRepresentation.cachedColumnBuffers.unpersist()
+ cd.cachedRepresentation.cacheBuilder.clearCache()
// Remove the cache entry before we create a new one, so that we can have a different
// physical plan.
it.remove()
+ val plan = spark.sessionState.executePlan(cd.plan).executedPlan
val newCache = InMemoryRelation(
- useCompression = cd.cachedRepresentation.useCompression,
- batchSize = cd.cachedRepresentation.batchSize,
- storageLevel = cd.cachedRepresentation.storageLevel,
- child = spark.sessionState.executePlan(cd.plan).executedPlan,
- tableName = cd.cachedRepresentation.tableName,
- statsOfPlanToCache = cd.plan.stats)
+ cacheBuilder = cd.cachedRepresentation
+ .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null),
+ logicalPlan = cd.plan)
needToRecache += cd.copy(cachedRepresentation = newCache)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 04f2619ed7541..48abad9078650 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -49,20 +50,24 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
ordinal: String,
dataType: DataType,
nullable: Boolean): ExprCode = {
- val javaType = ctx.javaType(dataType)
- val value = ctx.getValueFromVector(columnVar, dataType, ordinal)
- val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
+ val javaType = CodeGenerator.javaType(dataType)
+ val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal)
+ val isNullVar = if (nullable) {
+ JavaCode.isNullVariable(ctx.freshName("isNull"))
+ } else {
+ FalseLiteral
+ }
val valueVar = ctx.freshName("value")
val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
- val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
- s"""
+ val code = code"${ctx.registerComment(str)}" + (if (nullable) {
+ code"""
boolean $isNullVar = $columnVar.isNullAt($ordinal);
- $javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value);
+ $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
"""
} else {
- s"$javaType $valueVar = $value;"
- }).trim
- ExprCode(code, isNullVar, valueVar)
+ code"$javaType $valueVar = $value;"
+ })
+ ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
}
/**
@@ -85,12 +90,13 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
// metrics
val numOutputRows = metricTerm(ctx, "numOutputRows")
val scanTimeMetric = metricTerm(ctx, "scanTime")
- val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0
+ val scanTimeTotalNs =
+ ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0
val columnarBatchClz = classOf[ColumnarBatch].getName
val batch = ctx.addMutableState(columnarBatchClz, "batch")
- val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0
+ val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
val columnVectorClzs = vectorTypes.getOrElse(
Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index ba1157d5b6a49..d7f2654be0451 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.sources.{BaseRelation, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.BitSet
trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
val relation: BaseRelation
@@ -69,7 +70,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
* Shorthand for calling redactString() without specifying redacting rules
*/
private def redact(text: String): String = {
- Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text)
+ Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text)
}
}
@@ -151,6 +152,7 @@ case class RowDataSourceScanExec(
* @param output Output attributes of the scan, including data attributes and partition attributes.
* @param requiredSchema Required schema of the underlying relation, excluding partition columns.
* @param partitionFilters Predicates to use for partition pruning.
+ * @param optionalBucketSet Bucket ids for bucket pruning
* @param dataFilters Filters on non-partition columns.
* @param tableIdentifier identifier for the table in the metastore.
*/
@@ -159,6 +161,7 @@ case class FileSourceScanExec(
output: Seq[Attribute],
requiredSchema: StructType,
partitionFilters: Seq[Expression],
+ optionalBucketSet: Option[BitSet],
dataFilters: Seq[Expression],
override val tableIdentifier: Option[TableIdentifier])
extends DataSourceScanExec with ColumnarBatchScan {
@@ -286,7 +289,20 @@ case class FileSourceScanExec(
} getOrElse {
metadata
}
- withOptPartitionCount
+
+ val withSelectedBucketsCount = relation.bucketSpec.map { spec =>
+ val numSelectedBuckets = optionalBucketSet.map { b =>
+ b.cardinality()
+ } getOrElse {
+ spec.numBuckets
+ }
+ withOptPartitionCount + ("SelectedBucketsCount" ->
+ s"$numSelectedBuckets out of ${spec.numBuckets}")
+ } getOrElse {
+ withOptPartitionCount
+ }
+
+ withSelectedBucketsCount
}
private lazy val inputRDD: RDD[InternalRow] = {
@@ -365,7 +381,7 @@ case class FileSourceScanExec(
selectedPartitions: Seq[PartitionDirectory],
fsRelation: HadoopFsRelation): RDD[InternalRow] = {
logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
- val bucketed =
+ val filesGroupedToBuckets =
selectedPartitions.flatMap { p =>
p.files.map { f =>
val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen)
@@ -377,8 +393,17 @@ case class FileSourceScanExec(
.getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
}
+ val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
+ val bucketSet = optionalBucketSet.get
+ filesGroupedToBuckets.filter {
+ f => bucketSet.get(f._1)
+ }
+ } else {
+ filesGroupedToBuckets
+ }
+
val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
- FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil))
+ FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil))
}
new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions)
@@ -444,29 +469,16 @@ case class FileSourceScanExec(
currentSize = 0
}
- def addFile(file: PartitionedFile): Unit = {
- currentFiles += file
- currentSize += file.length + openCostInBytes
- }
-
- var frontIndex = 0
- var backIndex = splitFiles.length - 1
-
- while (frontIndex <= backIndex) {
- addFile(splitFiles(frontIndex))
- frontIndex += 1
- while (frontIndex <= backIndex &&
- currentSize + splitFiles(frontIndex).length <= maxSplitBytes) {
- addFile(splitFiles(frontIndex))
- frontIndex += 1
- }
- while (backIndex > frontIndex &&
- currentSize + splitFiles(backIndex).length <= maxSplitBytes) {
- addFile(splitFiles(backIndex))
- backIndex -= 1
+ // Assign files to partitions using "Next Fit Decreasing"
+ splitFiles.foreach { file =>
+ if (currentSize + file.length > maxSplitBytes) {
+ closePartition()
}
- closePartition()
+ // Add the given file to the current partition.
+ currentSize += file.length + openCostInBytes
+ currentFiles += file
}
+ closePartition()
new FileScanRDD(fsRelation.sparkSession, readFile, partitions)
}
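A standalone sketch (not Spark code) of the "Next Fit Decreasing" packing used above; it sorts split sizes itself and mirrors the close-then-add loop in the diff, with `openCostInBytes` charged per file:

```scala
import scala.collection.mutable.ArrayBuffer

def packSplits(sizes: Seq[Long], maxSplitBytes: Long, openCostInBytes: Long): Seq[Seq[Long]] = {
  val partitions = ArrayBuffer.empty[Seq[Long]]
  var current = ArrayBuffer.empty[Long]
  var currentSize = 0L

  def closePartition(): Unit = {
    if (current.nonEmpty) partitions += current.toSeq
    current = ArrayBuffer.empty[Long]
    currentSize = 0L
  }

  // Visit splits in decreasing size order; close the current partition whenever the
  // next split would push it past maxSplitBytes.
  sizes.sortBy(-_).foreach { size =>
    if (currentSize + size > maxSplitBytes) closePartition()
    currentSize += size + openCostInBytes
    current += size
  }
  closePartition()
  partitions.toSeq
}

// packSplits(Seq(120L, 60L, 30L, 10L), maxSplitBytes = 128L, openCostInBytes = 4L)
// => Seq(Seq(120), Seq(60, 30, 10))
```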
@@ -516,6 +528,7 @@ case class FileSourceScanExec(
output.map(QueryPlan.normalizeExprId(_, output)),
requiredSchema,
QueryPlan.normalizePredicates(partitionFilters, output),
+ optionalBucketSet,
QueryPlan.normalizePredicates(dataFilters, output),
None)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index f3555508185fe..be50a1571a2ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -125,7 +125,7 @@ case class LogicalRDD(
output: Seq[Attribute],
rdd: RDD[InternalRow],
outputPartitioning: Partitioning = UnknownPartitioning(0),
- outputOrdering: Seq[SortOrder] = Nil,
+ override val outputOrdering: Seq[SortOrder] = Nil,
override val isStreaming: Boolean = false)(session: SparkSession)
extends LeafNode with MultiInstanceRelation {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index a7bd5ebf93ecd..5b4edf5136e3f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -21,7 +21,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -152,11 +153,15 @@ case class ExpandExec(
} else {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
- val code = s"""
+ val code = code"""
|boolean $isNull = true;
- |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)};
+ |${CodeGenerator.javaType(firstExpr.dataType)} $value =
+ | ${CodeGenerator.defaultValue(firstExpr.dataType)};
""".stripMargin
- ExprCode(code, isNull, value)
+ ExprCode(
+ code,
+ JavaCode.isNullVariable(isNull),
+ JavaCode.variable(value, firstExpr.dataType))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 0c2c4a1a9100d..2549b9e1537a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+import org.apache.spark.sql.types._
/**
* For lazy computing, be sure the generator.terminate() called in the very last
@@ -170,9 +171,11 @@ case class GenerateExec(
// Add position
val position = if (e.position) {
if (outer) {
- Seq(ExprCode("", s"$index == -1", index))
+ Seq(ExprCode(
+ JavaCode.isNullExpression(s"$index == -1"),
+ JavaCode.variable(index, IntegerType)))
} else {
- Seq(ExprCode("", "false", index))
+ Seq(ExprCode(FalseLiteral, JavaCode.variable(index, IntegerType)))
}
} else {
Seq.empty
@@ -305,19 +308,19 @@ case class GenerateExec(
nullable: Boolean,
initialChecks: Seq[String]): ExprCode = {
val value = ctx.freshName(name)
- val javaType = ctx.javaType(dt)
- val getter = ctx.getValue(source, dt, index)
+ val javaType = CodeGenerator.javaType(dt)
+ val getter = CodeGenerator.getValue(source, dt, index)
val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)")
if (checks.nonEmpty) {
val isNull = ctx.freshName("isNull")
val code =
- s"""
+ code"""
|boolean $isNull = ${checks.mkString(" || ")};
- |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter;
+ |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter;
""".stripMargin
- ExprCode(code, isNull, value)
+ ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt))
} else {
- ExprCode(s"$javaType $value = $getter;", "false", value)
+ ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index 514ad7018d8c7..448eb703eacde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -25,6 +25,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* Physical plan node for scanning data from a local collection.
+ *
+ * `Seq` may not be serializable and ideally we should not send `rows` and `unsafeRows`
+ * to the executors. Thus marking them as transient.
*/
case class LocalTableScanExec(
output: Seq[Attribute],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index 18f6f697bc857..3ca03ab2939aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -17,10 +17,14 @@
package org.apache.spark.sql.execution
+import java.util.Locale
+
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
@@ -46,9 +50,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
}
plan.transform {
- case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation)) =>
+ case a @ Aggregate(_, aggExprs, child @ PhysicalOperation(
+ projectList, filters, PartitionedRelation(partAttrs, rel))) =>
// We only apply this optimization when only partitioned attributes are scanned.
- if (a.references.subsetOf(partAttrs)) {
+ if (AttributeSet((projectList ++ filters).flatMap(_.references)).subsetOf(partAttrs)) {
+ // The project list and filters all only refer to partition attributes, which means
+ // the Aggregate operator can also only refer to partition attributes, and filters are
+ // all partition filters. This is a metadata-only query we can optimize.
val aggFunctions = aggExprs.flatMap(_.collect {
case agg: AggregateExpression => agg
})
@@ -64,7 +72,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
})
}
if (isAllDistinctAgg) {
- a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, relation)))
+ a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, rel, filters)))
} else {
a
}
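For illustration (table and column names are made up): with partition filters now pushed into the partition listing, a distinct aggregate over partition columns can stay metadata-only even when filtered, assuming `spark.sql.optimizer.metadataOnly` is enabled:

```scala
// Answered from catalog/partition metadata, without scanning data files.
spark.sql("SELECT MAX(part_date) FROM events WHERE part_date >= '2018-01-01'")
```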
@@ -80,8 +88,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
private def getPartitionAttrs(
partitionColumnNames: Seq[String],
relation: LogicalPlan): Seq[Attribute] = {
- val partColumns = partitionColumnNames.map(_.toLowerCase).toSet
- relation.output.filter(a => partColumns.contains(a.name.toLowerCase))
+ val attrMap = relation.output.map(a => a.name.toLowerCase(Locale.ROOT) -> a).toMap
+ partitionColumnNames.map { colName =>
+ attrMap.getOrElse(colName.toLowerCase(Locale.ROOT),
+ throw new AnalysisException(s"Unable to find the column `$colName` " +
+ s"given [${relation.output.map(_.name).mkString(", ")}]")
+ )
+ }
}
/**
@@ -90,13 +103,23 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
*/
private def replaceTableScanWithPartitionMetadata(
child: LogicalPlan,
- relation: LogicalPlan): LogicalPlan = {
+ relation: LogicalPlan,
+ partFilters: Seq[Expression]): LogicalPlan = {
+ // This logic comes from PruneFileSourcePartitions. It ensures that the filter names match the
+ // relation's schema. PartitionedRelation ensures that the filters only reference partition cols.
+ val normalizedFilters = partFilters.map { e =>
+ e transform {
+ case a: AttributeReference =>
+ a.withName(relation.output.find(_.semanticEquals(a)).get.name)
+ }
+ }
+
child transform {
case plan if plan eq relation =>
relation match {
case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) =>
val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)
- val partitionData = fsRelation.location.listFiles(Nil, Nil)
+ val partitionData = fsRelation.location.listFiles(normalizedFilters, Nil)
LocalRelation(partAttrs, partitionData.map(_.values), isStreaming)
case relation: HiveTableRelation =>
@@ -105,7 +128,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
CaseInsensitiveMap(relation.tableMeta.storage.properties)
val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(SQLConf.get.sessionLocalTimeZone)
- val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p =>
+ val partitions = if (partFilters.nonEmpty) {
+ catalog.listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters)
+ } else {
+ catalog.listPartitions(relation.tableMeta.identifier)
+ }
+
+ val partitionData = partitions.map { p =>
InternalRow.fromSeq(partAttrs.map { attr =>
Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval()
})
@@ -122,34 +151,23 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
/**
* A pattern that finds the partitioned table relation node inside the given plan, and returns a
* pair of the partition attributes and the table relation node.
- *
- * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with
- * deterministic expressions, and returns result after reaching the partitioned table relation
- * node.
*/
- object PartitionedRelation {
-
- def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = plan match {
- case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)
- if fsRelation.partitionSchema.nonEmpty =>
- val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)
- Some((AttributeSet(partAttrs), l))
+ object PartitionedRelation extends PredicateHelper {
- case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty =>
- val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)
- Some((AttributeSet(partAttrs), relation))
+ def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = {
+ plan match {
+ case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)
+ if fsRelation.partitionSchema.nonEmpty =>
+ val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l))
+ Some((partAttrs, l))
- case p @ Project(projectList, child) if projectList.forall(_.deterministic) =>
- unapply(child).flatMap { case (partAttrs, relation) =>
- if (p.references.subsetOf(partAttrs)) Some((p.outputSet, relation)) else None
- }
-
- case f @ Filter(condition, child) if condition.deterministic =>
- unapply(child).flatMap { case (partAttrs, relation) =>
- if (f.references.subsetOf(partAttrs)) Some((partAttrs, relation)) else None
- }
+ case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty =>
+ val partAttrs = AttributeSet(
+ getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation))
+ Some((partAttrs, relation))
- case _ => None
+ case _ => None
+ }
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 7cae24bf5976c..3112b306c365e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -155,6 +155,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
case (null, _) => "null"
case (s: String, StringType) => "\"" + s + "\""
case (decimal, DecimalType()) => decimal.toString
+ case (interval, CalendarIntervalType) => interval.toString
case (other, tpe) if primitiveTypes contains tpe => other.toString
}
@@ -178,6 +179,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone))
case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal)
+ case (interval, CalendarIntervalType) => interval.toString
case (other, tpe) if primitiveTypes.contains(tpe) => other.toString
}
}
@@ -223,7 +225,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
* Redact the sensitive information in the given string.
*/
private def withRedaction(message: String): String = {
- Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message)
+ Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, message)
}
/** A special namespace for commands that can be used to debug query execution. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala
index 16806c620635f..cffd97baea6a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala
@@ -17,4 +17,5 @@
package org.apache.spark.sql.execution
-class QueryExecutionException(message: String) extends Exception(message)
+class QueryExecutionException(message: String, cause: Throwable = null)
+ extends Exception(message, cause)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index e991da7df0bde..439932b0cc3ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -68,16 +68,18 @@ object SQLExecution {
// sparkContext.getCallSite() would first try to pick up any call site that was previously
// set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
// streaming queries would give us call site like "run at :0"
- val callSite = sparkSession.sparkContext.getCallSite()
+ val callSite = sc.getCallSite()
- sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
- executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
- SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
- try {
- body
- } finally {
- sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
- executionId, System.currentTimeMillis()))
+ withSQLConfPropagated(sparkSession) {
+ sc.listenerBus.post(SparkListenerSQLExecutionStart(
+ executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
+ SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
+ try {
+ body
+ } finally {
+ sc.listenerBus.post(SparkListenerSQLExecutionEnd(
+ executionId, System.currentTimeMillis()))
+ }
}
} finally {
executionIdToQueryExecution.remove(executionId)
@@ -88,15 +90,43 @@ object SQLExecution {
/**
* Wrap an action with a known executionId. When running a different action in a different
* thread from the original one, this method can be used to connect the Spark jobs in this action
- * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`.
+ * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`.
*/
- def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
+ def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = {
+ val sc = sparkSession.sparkContext
val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ withSQLConfPropagated(sparkSession) {
+ try {
+ sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
+ body
+ } finally {
+ sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+ }
+ }
+ }
+
+ /**
+ * Wrap an action with specified SQL configs. These configs will be propagated to the executor
+ * side via job local properties.
+ */
+ def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = {
+ val sc = sparkSession.sparkContext
+ // Set all the specified SQL configs to local properties, so that they can be available on
+ // the executor side.
+ val allConfigs = sparkSession.sessionState.conf.getAllConfs
+ val originalLocalProps = allConfigs.collect {
+ case (key, value) if key.startsWith("spark") =>
+ val originalValue = sc.getLocalProperty(key)
+ sc.setLocalProperty(key, value)
+ (key, originalValue)
+ }
+
try {
- sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
body
} finally {
- sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+ for ((key, value) <- originalLocalProps) {
+ sc.setLocalProperty(key, value)
+ }
}
}
}
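A hypothetical standalone sketch of the save/set/restore pattern behind `withSQLConfPropagated`; the helper name and shape are illustrative, not the actual internals:

```scala
import org.apache.spark.SparkContext

def withLocalProps[T](sc: SparkContext, props: Map[String, String])(body: => T): T = {
  // Remember each property's previous value, overwrite it for the duration of `body`,
  // then always restore the originals (setLocalProperty(key, null) clears a property).
  val saved = props.map { case (k, v) =>
    val previous = sc.getLocalProperty(k)
    sc.setLocalProperty(k, v)
    k -> previous
  }
  try body finally {
    saved.foreach { case (k, previous) => sc.setLocalProperty(k, previous) }
  }
}
```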
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index ac1c34d41c4f1..0dc16ba5ce281 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -22,7 +22,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -133,7 +133,8 @@ case class SortExec(
override def needStopCheck: Boolean = false
override protected def doProduce(ctx: CodegenContext): String = {
- val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
+ val needToSort =
+ ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
// Initialize the class member variables. This includes the instance of the Sorter and
// the iterator to return sorted rows.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 1c8e4050978dc..00ff4c8ac310b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -21,7 +21,6 @@ import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
-import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
class SparkOptimizer(
@@ -32,8 +31,7 @@ class SparkOptimizer(
override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
- Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
- Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++
+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++
postHocOptimizationBatches :+
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 74048871f8d42..75f5ec0e253df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -41,6 +41,7 @@ class SparkPlanner(
DataSourceStrategy(conf) ::
SpecialLimits ::
Aggregation ::
+ Window ::
JoinSelection ::
InMemoryScans ::
BasicOperators :: Nil)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 82b4eb9fba242..d6951ad01fb0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -66,9 +66,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object SpecialLimits extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ReturnAnswer(rootPlan) => rootPlan match {
- case Limit(IntegerLiteral(limit), Sort(order, true, child)) =>
+ case Limit(IntegerLiteral(limit), Sort(order, true, child))
+ if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
- case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) =>
+ case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
+ if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
case Limit(IntegerLiteral(limit), child) =>
// With whole stage codegen, Spark releases resources only when all the output data of the
@@ -79,9 +81,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
case other => planLater(other) :: Nil
}
- case Limit(IntegerLiteral(limit), Sort(order, true, child)) =>
+ case Limit(IntegerLiteral(limit), Sort(order, true, child))
+ if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
- case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) =>
+ case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
+ if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
case _ => Nil
}
@@ -323,7 +327,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case PhysicalAggregation(
namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
- if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) {
+ if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) {
throw new AnalysisException(
"Streaming aggregation doesn't support group aggregate pandas UDF")
}
@@ -361,7 +365,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case Join(left, right, _, _) if left.isStreaming && right.isStreaming =>
throw new AnalysisException(
- "Stream stream joins without equality predicate is not supported", plan = Some(plan))
+ "Stream-stream join without equality predicate is not supported", plan = Some(plan))
case _ => Nil
}
@@ -380,9 +384,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val (functionsWithDistinct, functionsWithoutDistinct) =
aggregateExpressions.partition(_.isDistinct)
- if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+ if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) {
// This is a sanity check. We should not reach here when we have multiple distinct
- // column sets. Our MultipleDistinctRewriter should take care this case.
+ // column sets. Our `RewriteDistinctAggregates` should take care of this case.
sys.error("You hit a query analyzer bug. Please report your query to " +
"Spark user mailing list.")
}
@@ -424,6 +428,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
+ object Window extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case PhysicalWindow(
+ WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) =>
+ execution.window.WindowExec(
+ windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
+
+ case PhysicalWindow(
+ WindowFunctionType.Python, windowExprs, partitionSpec, orderSpec, child) =>
+ execution.python.WindowInPandasExec(
+ windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
+
+ case _ => Nil
+ }
+ }
+
protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
object InMemoryScans extends Strategy {
@@ -544,8 +564,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
- case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
- execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data, _) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 0e525b1e22eb9..372dc3db36ce6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -111,7 +112,7 @@ trait CodegenSupport extends SparkPlan {
private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
if (row != null) {
- ExprCode("", "false", row)
+ ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow]))
} else {
if (colVars.nonEmpty) {
val colExprs = output.zipWithIndex.map { case (attr, i) =>
@@ -122,14 +123,14 @@ trait CodegenSupport extends SparkPlan {
ctx.INPUT_ROW = row
ctx.currentVars = colVars
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
- val code = s"""
+ val code = code"""
|$evaluateInputs
- |${ev.code.trim}
- """.stripMargin.trim
- ExprCode(code, "false", ev.value)
+ |${ev.code}
+ """.stripMargin
+ ExprCode(code, FalseLiteral, ev.value)
} else {
- // There is no columns
- ExprCode("", "false", "unsafeRow")
+ // There are no columns
+ ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow]))
}
}
}
@@ -174,8 +175,9 @@ trait CodegenSupport extends SparkPlan {
// declaration.
val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator
val requireAllOutput = output.forall(parent.usedInputs.contains(_))
- val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0)
- val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) {
+ val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0)
+ val consumeFunc = if (confEnabled && requireAllOutput
+ && CodeGenerator.isValidParamLength(paramLength)) {
constructDoConsumeFunction(ctx, inputVars, row)
} else {
parent.doConsume(ctx, inputVars, rowVar)
@@ -234,21 +236,21 @@ trait CodegenSupport extends SparkPlan {
variables.zipWithIndex.foreach { case (ev, i) =>
val paramName = ctx.freshName(s"expr_$i")
- val paramType = ctx.javaType(attributes(i).dataType)
+ val paramType = CodeGenerator.javaType(attributes(i).dataType)
arguments += ev.value
parameters += s"$paramType $paramName"
val paramIsNull = if (!attributes(i).nullable) {
// Use constant `false` without passing `isNull` for non-nullable variable.
- "false"
+ FalseLiteral
} else {
val isNull = ctx.freshName(s"exprIsNull_$i")
arguments += ev.isNull
parameters += s"boolean $isNull"
- isNull
+ JavaCode.isNullVariable(isNull)
}
- paramVars += ExprCode("", paramIsNull, paramName)
+ paramVars += ExprCode(paramIsNull, JavaCode.variable(paramName, attributes(i).dataType))
}
(arguments, parameters, paramVars)
}
@@ -258,8 +260,8 @@ trait CodegenSupport extends SparkPlan {
* them to be evaluated twice.
*/
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
- val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
- variables.foreach(_.code = "")
+ val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n")
+ variables.foreach(_.code = EmptyBlock)
evaluate
}
@@ -274,8 +276,8 @@ trait CodegenSupport extends SparkPlan {
val evaluateVars = new StringBuilder
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
- evaluateVars.append(ev.code.trim + "\n")
- ev.code = ""
+ evaluateVars.append(ev.code.toString + "\n")
+ ev.code = EmptyBlock
}
}
evaluateVars.toString()
@@ -540,7 +542,9 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
${ctx.registerComment(
s"""Codegend pipeline for stage (id=$codegenStageId)
- |${this.treeString.trim}""".stripMargin)}
+ |${this.treeString.trim}""".stripMargin,
+ "wsc_codegenPipeline")}
+ ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)}
final class $className extends ${classOf[BufferedRowIterator].getName} {
private Object[] references;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index ce3c68810f3b6..8c7b2c187cccd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -178,7 +179,7 @@ case class HashAggregateExec(
private var bufVars: Seq[ExprCode] = _
private def doProduceWithoutKeys(ctx: CodegenContext): String = {
- val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg")
+ val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
// The generated function doesn't have input row in the code context.
ctx.INPUT_ROW = null
@@ -186,15 +187,18 @@ case class HashAggregateExec(
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val initExpr = functions.flatMap(f => f.initialValues)
bufVars = initExpr.map { e =>
- val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
- val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue")
+ val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
+ val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
// The initial expression should not access any column
val ev = e.genCode(ctx)
- val initVars = s"""
+ val initVars = code"""
| $isNull = ${ev.isNull};
| $value = ${ev.value};
""".stripMargin
- ExprCode(ev.code + initVars, isNull, value)
+ ExprCode(
+ ev.code + initVars,
+ JavaCode.isNullGlobal(isNull),
+ JavaCode.global(value, e.dataType))
}
val initBufVar = evaluateVariables(bufVars)
@@ -532,7 +536,7 @@ case class HashAggregateExec(
*/
private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = {
val isSupported =
- (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
+ (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
@@ -565,7 +569,7 @@ case class HashAggregateExec(
}
private def doProduceWithKeys(ctx: CodegenContext): String = {
- val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg")
+ val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
if (sqlContext.conf.enableTwoLevelAggMap) {
enableTwoLevelHashMap(ctx)
} else {
@@ -752,12 +756,15 @@ case class HashAggregateExec(
}
// generate hash code for key
- val hashExpr = Murmur3Hash(groupingExpressions, 42)
+ // SPARK-24076: HashAggregate uses the same hash algorithm on the same expressions
+ // as ShuffleExchange, which may lead to bad hash conflicts when shuffle.partitions=8192*n,
+ // so pick a different seed to avoid this conflict
+ val hashExpr = Murmur3Hash(groupingExpressions, 48)
val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx)
val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
incCounter) = if (testFallbackStartsAt.isDefined) {
- val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter")
+ val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
(s"$countTerm < ${testFallbackStartsAt.get._1}",
s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
} else {
@@ -767,8 +774,8 @@ case class HashAggregateExec(
val findOrInsertRegularHashMap: String =
s"""
|// generate grouping key
- |${unsafeRowKeyCode.code.trim}
- |${hashEval.code.trim}
+ |${unsafeRowKeyCode.code}
+ |${hashEval.code}
|if ($checkFallbackForBytesToBytesMap) {
| // try to get the buffer from hash map
| $unsafeRowBuffer =
@@ -832,7 +839,7 @@ case class HashAggregateExec(
}
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
- ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+ CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
}
s"""
|// common sub-expressions
@@ -855,7 +862,7 @@ case class HashAggregateExec(
}
val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
- ctx.updateColumn(
+ CodeGenerator.updateColumn(
fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled)
}
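
The seed change above (42 -> 48) is easier to see with a standalone sketch. The snippet below does not use Spark's code path: scala.util.hashing.MurmurHash3 stands in for the Murmur3Hash expression, and the power-of-two `mapCapacity` mask is only a stand-in for the fast hash map's bucket computation. It just illustrates why reusing the shuffle seed crowds the keys of one shuffle partition into a handful of aggregation buckets when shuffle.partitions is a power of two.

    import scala.util.hashing.MurmurHash3

    object SeedCollisionSketch {
      def main(args: Array[String]): Unit = {
        val numShufflePartitions = 8192   // matches the 8192 * n case from the comment
        val mapCapacity = 1 << 16         // hypothetical power-of-two aggregate hash map size

        // Keys that the exchange would route to shuffle partition 0 when hashing with seed 42.
        val keysInPartition0 = (0 until 1000000).map(i => s"key_$i").filter { k =>
          Math.floorMod(MurmurHash3.stringHash(k, 42), numShufflePartitions) == 0
        }

        def bucketsUsed(seed: Int): Int =
          keysInPartition0
            .map(k => MurmurHash3.stringHash(k, seed) & (mapCapacity - 1))
            .distinct
            .size

        // Same seed as the exchange: the low 13 bits of every hash are already fixed by the
        // routing, so at most 8 of the 65536 buckets are ever touched.
        println(s"distinct buckets with seed 42: ${bucketsUsed(42)}")
        // Different seed: the bucket index is decorrelated from the partition id.
        println(s"distinct buckets with seed 48: ${bucketsUsed(48)}")
      }
    }
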
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
index 1c613b19c4ab1..e1c85823259b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@@ -41,20 +42,23 @@ abstract class HashMapGenerator(
val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key")))
val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value")))
val groupingKeySignature =
- groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ")
+ groupingKeys.map(key => s"${CodeGenerator.javaType(key.dataType)} ${key.name}").mkString(", ")
val buffVars: Seq[ExprCode] = {
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val initExpr = functions.flatMap(f => f.initialValues)
initExpr.map { e =>
- val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
- val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue")
+ val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
+ val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
val ev = e.genCode(ctx)
val initVars =
- s"""
+ code"""
| $isNull = ${ev.isNull};
| $value = ${ev.value};
""".stripMargin
- ExprCode(ev.code + initVars, isNull, value)
+ ExprCode(
+ ev.code + initVars,
+ JavaCode.isNullGlobal(isNull),
+ JavaCode.global(value, e.dataType))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index fd25707dd4ca6..d5508275c48c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression}
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}
import org.apache.spark.sql.types._
/**
@@ -114,7 +114,7 @@ class RowBasedHashMapGenerator(
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- s"""(${ctx.genEqual(key.dataType, ctx.getValue("row",
+ s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue("row",
key.dataType, ordinal.toString()), key.name)})"""
}.mkString(" && ")
}
@@ -147,7 +147,7 @@ class RowBasedHashMapGenerator(
case t: DecimalType =>
s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})"
case t: DataType =>
- if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) {
+ if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) {
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t")
}
s"agg_rowWriter.write(${ordinal}, ${key.name})"
@@ -165,18 +165,14 @@ class RowBasedHashMapGenerator(
| if (buckets[idx] == -1) {
| if (numRows < capacity && !isBatchFull) {
| // creating the unsafe for new entry
- | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length});
- | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder
- | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result,
- | ${numVarLenFields * 32});
| org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter
| = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(
- | agg_holder,
- | ${groupingKeySchema.length});
- | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed
+ | ${groupingKeySchema.length}, ${numVarLenFields * 32});
+ | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed
| agg_rowWriter.zeroOutNullBytes();
| ${createUnsafeRowForKey};
- | agg_result.setTotalSize(agg_holder.totalSize());
+ | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result
+ | = agg_rowWriter.getRow();
| Object kbase = agg_result.getBaseObject();
| long koff = agg_result.getBaseOffset();
| int klen = agg_result.getSizeInBytes();
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index aab8cc50b9526..6d44890704f49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
object TypedAggregateExpression {
def apply[BUF : Encoder, OUT : Encoder](
@@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction {
s"$nodeName($input)"
}
- override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
+ // aggregator.getClass.getSimpleName can cause a "Malformed class name" error;
+ // call the safer `Utils.getSimpleName` instead.
+ override def nodeName: String = Utils.getSimpleName(aggregator.getClass).stripSuffix("$");
}
// TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface.
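
The "Malformed class name" issue above comes from Class.getSimpleName throwing java.lang.InternalError for some deeply nested Scala classes on older JDKs. A rough sketch of the fallback idea follows; it is only an illustration of the approach, not the actual Utils.getSimpleName implementation.

    // Sketch of a defensive wrapper: fall back to parsing Class.getName when
    // getSimpleName blows up with InternalError("Malformed class name").
    def safeSimpleName(cls: Class[_]): String =
      try {
        cls.getSimpleName
      } catch {
        case _: InternalError =>
          // Strip the package prefix and any trailing '$' segments from the binary name.
          cls.getName.split('.').last.split('$').filter(_.nonEmpty).lastOption.getOrElse(cls.getName)
      }
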
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 633eeac180974..7b3580cecc60d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}
import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -127,7 +127,8 @@ class VectorizedHashMapGenerator(
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]")
+ val value = CodeGenerator.getValueFromVector(s"vectors[$ordinal]", key.dataType,
+ "buckets[idx]")
s"(${ctx.genEqual(key.dataType, value, key.name)})"
}.mkString(" && ")
}
@@ -182,14 +183,14 @@ class VectorizedHashMapGenerator(
def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
+ CodeGenerator.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
}
}
def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
- buffVars(ordinal), nullable = true)
+ CodeGenerator.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows",
+ key.dataType, buffVars(ordinal), nullable = true)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 22b63513548fe..66888fce7f9f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -133,6 +133,14 @@ private[arrow] abstract class ArrowFieldWriter {
valueVector match {
case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset()
case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset()
+ case listVector: ListVector =>
+ // Manually "reset" the underlying buffer.
+ // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call
+ // `listVector.reset()`.
+ val buffers = listVector.getBuffers(false)
+ buffers.foreach(buf => buf.setZero(0, buf.capacity()))
+ listVector.setValueCount(0)
+ listVector.setLastSet(0)
case _ =>
}
count = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index a15a8d11aa2a0..9434ceb7cd16c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
@@ -192,7 +192,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
// generate better code (remove dead branches).
val resultVars = input.zipWithIndex.map { case (ev, i) =>
if (notNullAttributes.contains(child.output(i).exprId)) {
- ev.isNull = "false"
+ ev.isNull = FalseLiteral
}
ev
}
@@ -345,6 +345,20 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
override val output: Seq[Attribute] = range.output
+ override def outputOrdering: Seq[SortOrder] = range.outputOrdering
+
+ override def outputPartitioning: Partitioning = {
+ if (numElements > 0) {
+ if (numSlices == 1) {
+ SinglePartition
+ } else {
+ RangePartitioning(outputOrdering, numSlices)
+ }
+ } else {
+ UnknownPartitioning(0)
+ }
+ }
+
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -364,11 +378,11 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
protected override def doProduce(ctx: CodegenContext): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
- val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange")
- val number = ctx.addMutableState(ctx.JAVA_LONG, "number")
+ val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
+ val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number")
val value = ctx.freshName("value")
- val ev = ExprCode("", "false", value)
+ val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
val BigInt = classOf[java.math.BigInteger].getName
// Inline mutable state since not many Range operations in a task
@@ -385,10 +399,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
// the metrics.
// Once number == batchEnd, it's time to progress to the next batch.
- val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd")
+ val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
// How many values should still be generated by this range operator.
- val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo")
+ val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo")
// How many values should be generated in the next batch.
val nextBatchTodo = ctx.freshName("nextBatchTodo")
@@ -629,7 +643,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
- SQLExecution.withExecutionId(sparkContext, executionId) {
+ SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
val beforeCollect = System.nanoTime()
// Note that we use .executeCollect() because we don't want to convert data to Scala types
val rows: Array[InternalRow] = child.executeCollect()
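
The new outputOrdering and outputPartitioning overrides on RangeExec let the planner drop work that the range output already satisfies. A quick way to observe the intended effect from a shell session (assuming a `spark` SparkSession; the exact plan shape depends on the optimizer in the build you run this against):

    // Sorting the output of range should need no extra Exchange/Sort once RangeExec
    // reports that it is already range-partitioned and ordered by `id`.
    val df = spark.range(0, 1000000, 1, numPartitions = 8)
    df.sort("id").explain()   // expected: no Sort / Exchange nodes above the Range scan
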
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 4f28eeb725cbb..2d699e8a9d088 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -91,7 +91,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
val accessorName = ctx.addMutableState(accessorCls, "accessor")
val createCode = dt match {
- case t if ctx.isPrimitiveType(dt) =>
+ case t if CodeGenerator.isPrimitiveType(dt) =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
case NullType | StringType | BinaryType =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
@@ -165,9 +165,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
private ByteOrder nativeOrder = null;
private byte[][] buffers = null;
- private UnsafeRow unsafeRow = new UnsafeRow($numFields);
- private BufferHolder bufferHolder = new BufferHolder(unsafeRow);
- private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields);
+ private UnsafeRowWriter rowWriter = new UnsafeRowWriter($numFields);
private MutableUnsafeRow mutableRow = null;
private int currentRow = 0;
@@ -212,11 +210,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
public InternalRow next() {
currentRow += 1;
- bufferHolder.reset();
+ rowWriter.reset();
rowWriter.zeroOutNullBytes();
${extractorCalls}
- unsafeRow.setTotalSize(bufferHolder.totalSize());
- return unsafeRow;
+ return rowWriter.getRow();
}
${ctx.declareAddedFunctions()}
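
Both this hunk and the RowBasedHashMapGenerator change above replace the BufferHolder + UnsafeRow pair with a writer that owns its row. A minimal sketch of the new pattern as the generated code uses it; the field values and buffer size below are made up for illustration.

    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
    import org.apache.spark.unsafe.types.UTF8String

    // 2 fields, 32 bytes reserved for variable-length data (both numbers are arbitrary here).
    val rowWriter = new UnsafeRowWriter(2, 32)
    rowWriter.reset()              // start a fresh row, reusing the writer's buffer
    rowWriter.zeroOutNullBytes()
    rowWriter.write(0, 42L)
    rowWriter.write(1, UTF8String.fromString("abc"))
    val unsafeRow = rowWriter.getRow()   // no separate BufferHolder or setTotalSize needed
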
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 22e16913d4da9..da35a4734e65a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -24,26 +24,14 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.LongAccumulator
-object InMemoryRelation {
- def apply(
- useCompression: Boolean,
- batchSize: Int,
- storageLevel: StorageLevel,
- child: SparkPlan,
- tableName: Option[String],
- statsOfPlanToCache: Statistics): InMemoryRelation =
- new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)(
- statsOfPlanToCache = statsOfPlanToCache)
-}
-
-
/**
* CachedBatch is a cached batch of rows.
*
@@ -54,47 +42,41 @@ object InMemoryRelation {
private[columnar]
case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
-case class InMemoryRelation(
- output: Seq[Attribute],
+case class CachedRDDBuilder(
useCompression: Boolean,
batchSize: Int,
storageLevel: StorageLevel,
- @transient child: SparkPlan,
+ @transient cachedPlan: SparkPlan,
tableName: Option[String])(
- @transient var _cachedColumnBuffers: RDD[CachedBatch] = null,
- val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator,
- statsOfPlanToCache: Statistics)
- extends logical.LeafNode with MultiInstanceRelation {
+ @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) {
- override protected def innerChildren: Seq[SparkPlan] = Seq(child)
+ val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator
- override def producedAttributes: AttributeSet = outputSet
-
- @transient val partitionStatistics = new PartitionStatistics(output)
-
- override def computeStats(): Statistics = {
- if (sizeInBytesStats.value == 0L) {
- // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
- // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
- // node. When we lookup the cache with a semantically same plan without hint info, the plan
- // returned by cache lookup should not have hint info. If we lookup the cache with a
- // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
- // care of it and retain the hint info in the lookup input plan.
- statsOfPlanToCache.copy(hints = HintInfo())
- } else {
- Statistics(sizeInBytes = sizeInBytesStats.value.longValue)
+ def cachedColumnBuffers: RDD[CachedBatch] = {
+ if (_cachedColumnBuffers == null) {
+ synchronized {
+ if (_cachedColumnBuffers == null) {
+ _cachedColumnBuffers = buildBuffers()
+ }
+ }
}
+ _cachedColumnBuffers
}
- // If the cached column buffers were not passed in, we calculate them in the constructor.
- // As in Spark, the actual work of caching is lazy.
- if (_cachedColumnBuffers == null) {
- buildBuffers()
+ def clearCache(blocking: Boolean = true): Unit = {
+ if (_cachedColumnBuffers != null) {
+ synchronized {
+ if (_cachedColumnBuffers != null) {
+ _cachedColumnBuffers.unpersist(blocking)
+ _cachedColumnBuffers = null
+ }
+ }
+ }
}
- private def buildBuffers(): Unit = {
- val output = child.output
- val cached = child.execute().mapPartitionsInternal { rowIterator =>
+ private def buildBuffers(): RDD[CachedBatch] = {
+ val output = cachedPlan.output
+ val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator =>
new Iterator[CachedBatch] {
def next(): CachedBatch = {
val columnBuilders = output.map { attribute =>
@@ -142,31 +124,77 @@ case class InMemoryRelation(
cached.setName(
tableName.map(n => s"In-memory table $n")
- .getOrElse(StringUtils.abbreviate(child.toString, 1024)))
- _cachedColumnBuffers = cached
+ .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024)))
+ cached
+ }
+}
+
+object InMemoryRelation {
+
+ def apply(
+ useCompression: Boolean,
+ batchSize: Int,
+ storageLevel: StorageLevel,
+ child: SparkPlan,
+ tableName: Option[String],
+ logicalPlan: LogicalPlan): InMemoryRelation = {
+ val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)()
+ new InMemoryRelation(child.output, cacheBuilder)(
+ statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
+ }
+
+ def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = {
+ new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)(
+ statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering)
+ }
+}
+
+case class InMemoryRelation(
+ output: Seq[Attribute],
+ @transient cacheBuilder: CachedRDDBuilder)(
+ statsOfPlanToCache: Statistics,
+ override val outputOrdering: Seq[SortOrder])
+ extends logical.LeafNode with MultiInstanceRelation {
+
+ override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan)
+
+ override def doCanonicalize(): logical.LogicalPlan =
+ copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)),
+ cacheBuilder)(
+ statsOfPlanToCache,
+ outputOrdering)
+
+ override def producedAttributes: AttributeSet = outputSet
+
+ @transient val partitionStatistics = new PartitionStatistics(output)
+
+ def cachedPlan: SparkPlan = cacheBuilder.cachedPlan
+
+ override def computeStats(): Statistics = {
+ if (cacheBuilder.sizeInBytesStats.value == 0L) {
+ // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache.
+ // Note that we should drop the hint info here. We may cache a plan whose root node is a hint
+ // node. When we lookup the cache with a semantically same plan without hint info, the plan
+ // returned by the cache lookup should not have hint info. If we look up the cache with a
+ // semantically same plan with a different hint info, `CacheManager.useCachedData` will take
+ // care of it and retain the hint info in the lookup input plan.
+ statsOfPlanToCache.copy(hints = HintInfo())
+ } else {
+ Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue)
+ }
}
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
- InMemoryRelation(
- newOutput, useCompression, batchSize, storageLevel, child, tableName)(
- _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
+ InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering)
}
override def newInstance(): this.type = {
new InMemoryRelation(
output.map(_.newInstance()),
- useCompression,
- batchSize,
- storageLevel,
- child,
- tableName)(
- _cachedColumnBuffers,
- sizeInBytesStats,
- statsOfPlanToCache).asInstanceOf[this.type]
+ cacheBuilder)(
+ statsOfPlanToCache,
+ outputOrdering).asInstanceOf[this.type]
}
- def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
-
- override protected def otherCopyArgs: Seq[AnyRef] =
- Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)
+ override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache)
}
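
CachedRDDBuilder now builds the cached RDD lazily and guards both the build and the unpersist with double-checked locking. A stripped-down, generic sketch of that idiom (not the Spark class itself, and it adds @volatile, which the idiom normally requires):

    // Lazily build an expensive resource once, and allow it to be torn down and rebuilt.
    class LazyCache[T <: AnyRef](build: () => T, release: T => Unit) {
      @volatile private var cached: T = _

      def get(): T = {
        if (cached == null) {
          synchronized {
            if (cached == null) {
              cached = build()
            }
          }
        }
        cached
      }

      def clear(): Unit = {
        if (cached != null) {
          synchronized {
            if (cached != null) {
              release(cached)
              cached = null.asInstanceOf[T]
            }
          }
        }
      }
    }

In the builder, clearCache additionally unpersists the RDD, so clearing actually frees executor memory before the reference is dropped.
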
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index a93e8a1ad954d..0b4dd76c7d860 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
-import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.vectorized._
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -38,6 +38,11 @@ case class InMemoryTableScanExec(
override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
+ override def doCanonicalize(): SparkPlan =
+ copy(attributes = attributes.map(QueryPlan.normalizeExprId(_, relation.output)),
+ predicates = predicates.map(QueryPlan.normalizeExprId(_, relation.output)),
+ relation = relation.canonicalized.asInstanceOf[InMemoryRelation])
+
override def vectorTypes: Option[Seq[String]] =
Option(Seq.fill(attributes.length)(
if (!conf.offHeapColumnVectorEnabled) {
@@ -73,10 +78,12 @@ case class InMemoryTableScanExec(
private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i)))
- private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = {
+ private def createAndDecompressColumn(
+ cachedColumnarBatch: CachedBatch,
+ offHeapColumnVectorEnabled: Boolean): ColumnarBatch = {
val rowCount = cachedColumnarBatch.numRows
val taskContext = Option(TaskContext.get())
- val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) {
+ val columnVectors = if (!offHeapColumnVectorEnabled || taskContext.isEmpty) {
OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema)
} else {
OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema)
@@ -96,10 +103,13 @@ case class InMemoryTableScanExec(
private lazy val inputRDD: RDD[InternalRow] = {
val buffers = filteredCachedBatches()
+ val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled
if (supportsBatch) {
// HACK ALERT: This is actually an RDD[ColumnarBatch].
// We're taking advantage of Scala's type erasure here to pass these batches along.
- buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]]
+ buffers
+ .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled))
+ .asInstanceOf[RDD[InternalRow]]
} else {
val numOutputRows = longMetric("numOutputRows")
@@ -149,7 +159,7 @@ case class InMemoryTableScanExec(
private def updateAttribute(expr: Expression): Expression = {
// attributes can be pruned so using relation's output.
// E.g., relation.output is [id, item] but this scan's output can be [item] only.
- val attrMap = AttributeMap(relation.child.output.zip(relation.output))
+ val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output))
expr.transform {
case attr: Attribute => attrMap.getOrElse(attr, attr)
}
@@ -158,22 +168,24 @@ case class InMemoryTableScanExec(
// The cached version does not change the outputPartitioning of the original SparkPlan.
// But the cached version could alias output, so we need to replace output.
override def outputPartitioning: Partitioning = {
- relation.child.outputPartitioning match {
+ relation.cachedPlan.outputPartitioning match {
case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning]
- case _ => relation.child.outputPartitioning
+ case _ => relation.cachedPlan.outputPartitioning
}
}
// The cached version does not change the outputOrdering of the original SparkPlan.
// But the cached version could alias output, so we need to replace output.
override def outputOrdering: Seq[SortOrder] =
- relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
+ relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])
- private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
+ // Keeps the relation's partition statistics because we don't serialize the relation.
+ private val stats = relation.partitionStatistics
+ private def statsFor(a: Attribute) = stats.forAttribute(a)
// Returned filter predicate should return false iff it is impossible for the input expression
// to evaluate to `true' based on statistics collected about this partition batch.
- @transient val buildFilter: PartialFunction[Expression, Expression] = {
+ @transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
case And(lhs: Expression, rhs: Expression)
if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) =>
(buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _)
@@ -213,14 +225,14 @@ case class InMemoryTableScanExec(
l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
}
- val partitionFilters: Seq[Expression] = {
+ lazy val partitionFilters: Seq[Expression] = {
predicates.flatMap { p =>
val filter = buildFilter.lift(p)
val boundFilter =
filter.map(
BindReferences.bindReference(
_,
- relation.partitionStatistics.schema,
+ stats.schema,
allowFailures = true))
boundFilter.foreach(_ =>
@@ -243,9 +255,9 @@ case class InMemoryTableScanExec(
private def filteredCachedBatches(): RDD[CachedBatch] = {
// Using these variables here to avoid serialization of entire objects (if referenced directly)
// within the map Partitions closure.
- val schema = relation.partitionStatistics.schema
+ val schema = stats.schema
val schemaIndex = schema.zipWithIndex
- val buffers = relation.cachedColumnBuffers
+ val buffers = relation.cacheBuilder.cachedColumnBuffers
buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
val partitionFilter = newPredicate(
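
Several changes in this file (copying conf.offHeapColumnVectorEnabled into a local, keeping `stats` outside the relation, making buildFilter lazy) are about keeping driver-side plan objects out of task closures. A small sketch of the capture problem, using hypothetical names:

    import org.apache.spark.rdd.RDD

    class HypotheticalScan(@transient val conf: Map[String, Boolean]) {
      // Referring to `conf` inside the closure captures `this`, so the whole (possibly
      // non-serializable) scan object would have to be shipped to executors.
      def badMap(rdd: RDD[Int]): RDD[Int] =
        rdd.map(i => if (conf("offHeap")) i else -i)

      // Copying the needed value into a local first means only a Boolean is captured.
      def goodMap(rdd: RDD[Int]): RDD[Int] = {
        val offHeap = conf("offHeap")
        rdd.map(i => if (offHeap) i else -i)
      }
    }
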
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
index 79dcf3a6105ce..00a1d54b41709 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
@@ -116,7 +116,7 @@ private[columnar] case object PassThrough extends CompressionScheme {
while (pos < capacity) {
if (pos != nextNullIndex) {
val len = nextNullIndex - pos
- assert(len * unitSize < Int.MaxValue)
+ assert(len * unitSize.toLong < Int.MaxValue)
putFunction(columnVector, pos, bufferPos, len)
bufferPos += len * unitSize
pos += len
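
The `.toLong` matters because both operands of the original assert are Ints, so the product can wrap before the comparison. A concrete case (numbers chosen only to trigger the wrap):

    val len = 600000000        // 600M elements
    val unitSize = 8           // e.g. 8-byte values
    // Int arithmetic wraps: 4,800,000,000 overflows to 505,032,704, so this check passes
    // even though the real product exceeds Int.MaxValue.
    assert(len * unitSize < Int.MaxValue)
    // Promoting one operand to Long keeps the full product and the assert fails as intended.
    assert(len * unitSize.toLong < Int.MaxValue)   // throws AssertionError
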
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 1122522ccb4cb..640e01336aa75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.command
import scala.collection.mutable
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType}
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTableType}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
/**
@@ -64,12 +66,12 @@ case class AnalyzeColumnCommand(
/**
* Compute stats for the given columns.
- * @return (row count, map from column name to ColumnStats)
+ * @return (row count, map from column name to CatalogColumnStats)
*/
private def computeColumnStats(
sparkSession: SparkSession,
tableIdent: TableIdentifier,
- columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = {
+ columnNames: Seq[String]): (Long, Map[String, CatalogColumnStat]) = {
val conf = sparkSession.sessionState.conf
val relation = sparkSession.table(tableIdent).logicalPlan
@@ -81,7 +83,7 @@ case class AnalyzeColumnCommand(
// Make sure the column types are supported for stats gathering.
attributesToAnalyze.foreach { attr =>
- if (!ColumnStat.supportsType(attr.dataType)) {
+ if (!supportsType(attr.dataType)) {
throw new AnalysisException(
s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " +
"and Spark does not support statistics collection on this column type.")
@@ -103,7 +105,7 @@ case class AnalyzeColumnCommand(
// will be structs containing all column stats.
// The layout of each struct follows the layout of the ColumnStats.
val expressions = Count(Literal(1)).toAggregateExpression() +:
- attributesToAnalyze.map(ColumnStat.statExprs(_, conf, attributePercentiles))
+ attributesToAnalyze.map(statExprs(_, conf, attributePercentiles))
val namedExpressions = expressions.map(e => Alias(e, e.toString)())
val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation))
@@ -111,9 +113,9 @@ case class AnalyzeColumnCommand(
val rowCount = statsRow.getLong(0)
val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) =>
- // according to `ColumnStat.statExprs`, the stats struct always have 7 fields.
- (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount,
- attributePercentiles.get(attr)))
+ // according to `statExprs`, the stats struct always has 7 fields.
+ (attr.name, rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount,
+ attributePercentiles.get(attr)).toCatalogColumnStat(attr.name, attr.dataType))
}.toMap
(rowCount, columnStats)
}
@@ -124,7 +126,7 @@ case class AnalyzeColumnCommand(
sparkSession: SparkSession,
relation: LogicalPlan): AttributeMap[ArrayData] = {
val attrsToGenHistogram = if (conf.histogramEnabled) {
- attributesToAnalyze.filter(a => ColumnStat.supportsHistogram(a.dataType))
+ attributesToAnalyze.filter(a => supportsHistogram(a.dataType))
} else {
Nil
}
@@ -154,4 +156,120 @@ case class AnalyzeColumnCommand(
AttributeMap(attributePercentiles.toSeq)
}
+ /** Returns true iff we support gathering column statistics on a column of the given type. */
+ private def supportsType(dataType: DataType): Boolean = dataType match {
+ case _: IntegralType => true
+ case _: DecimalType => true
+ case DoubleType | FloatType => true
+ case BooleanType => true
+ case DateType => true
+ case TimestampType => true
+ case BinaryType | StringType => true
+ case _ => false
+ }
+
+ /** Returns true iff we support gathering a histogram on a column of the given type. */
+ private def supportsHistogram(dataType: DataType): Boolean = dataType match {
+ case _: IntegralType => true
+ case _: DecimalType => true
+ case DoubleType | FloatType => true
+ case DateType => true
+ case TimestampType => true
+ case _ => false
+ }
+
+ /**
+ * Constructs an expression to compute column statistics for a given column.
+ *
+ * The expression should create a single struct column with the following schema:
+ * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long,
+ * distinctCountsForIntervals: Array[Long]
+ *
+ * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and
+ * as a result should stay in sync with it.
+ */
+ private def statExprs(
+ col: Attribute,
+ conf: SQLConf,
+ colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = {
+ def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr =>
+ expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() }
+ })
+ val one = Literal(1, LongType)
+
+ // the approximate ndv (num distinct value) should never be larger than the number of rows
+ val numNonNulls = if (col.nullable) Count(col) else Count(one)
+ val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls))
+ val numNulls = Subtract(Count(one), numNonNulls)
+ val defaultSize = Literal(col.dataType.defaultSize, LongType)
+ val nullArray = Literal(null, ArrayType(LongType))
+
+ def fixedLenTypeStruct: CreateNamedStruct = {
+ val genHistogram =
+ supportsHistogram(col.dataType) && colPercentiles.contains(col)
+ val intervalNdvsExpr = if (genHistogram) {
+ ApproxCountDistinctForIntervals(col,
+ Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError)
+ } else {
+ nullArray
+ }
+ // For fixed width types, avg size should be the same as max size.
+ struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls,
+ defaultSize, defaultSize, intervalNdvsExpr)
+ }
+
+ col.dataType match {
+ case _: IntegralType => fixedLenTypeStruct
+ case _: DecimalType => fixedLenTypeStruct
+ case DoubleType | FloatType => fixedLenTypeStruct
+ case BooleanType => fixedLenTypeStruct
+ case DateType => fixedLenTypeStruct
+ case TimestampType => fixedLenTypeStruct
+ case BinaryType | StringType =>
+ // For string and binary type, we don't compute min, max or histogram
+ val nullLit = Literal(null, col.dataType)
+ struct(
+ ndv, nullLit, nullLit, numNulls,
+ // Set avg/max size to default size if all the values are null or there is no value.
+ Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)),
+ Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)),
+ nullArray)
+ case _ =>
+ throw new AnalysisException("Analyzing column statistics is not supported for column " +
+ s"${col.name} of data type: ${col.dataType}.")
+ }
+ }
+
+ /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */
+ private def rowToColumnStat(
+ row: InternalRow,
+ attr: Attribute,
+ rowCount: Long,
+ percentiles: Option[ArrayData]): ColumnStat = {
+ // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins.
+ val cs = ColumnStat(
+ distinctCount = Option(BigInt(row.getLong(0))),
+ // for string/binary min/max, get should return null
+ min = Option(row.get(1, attr.dataType)),
+ max = Option(row.get(2, attr.dataType)),
+ nullCount = Option(BigInt(row.getLong(3))),
+ avgLen = Option(row.getLong(4)),
+ maxLen = Option(row.getLong(5))
+ )
+ if (row.isNullAt(6) || cs.nullCount.isEmpty) {
+ cs
+ } else {
+ val ndvs = row.getArray(6).toLongArray()
+ assert(percentiles.get.numElements() == ndvs.length + 1)
+ val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble)
+ // Construct equi-height histogram
+ val bins = ndvs.zipWithIndex.map { case (ndv, i) =>
+ HistogramBin(endpoints(i), endpoints(i + 1), ndv)
+ }
+ val nonNullRows = rowCount - cs.nullCount.get
+ val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins)
+ cs.copy(histogram = Some(histogram))
+ }
+ }
+
}
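
To make the histogram construction in `rowToColumnStat` concrete, here is a worked example with made-up numbers: percentile endpoints [1.0, 10.0, 20.0, 30.0], per-interval NDVs [5, 7, 6] from field 6, rowCount = 300 and no nulls, so each equi-height bin holds 300 / 3 = 100 rows. A local `Bin` case class stands in for catalyst's HistogramBin to keep the sketch self-contained.

    // Hypothetical inputs mirroring the code above.
    val endpoints = Array(1.0, 10.0, 20.0, 30.0)   // percentiles.get, as doubles
    val ndvs = Array(5L, 7L, 6L)                   // row.getArray(6).toLongArray()
    val rowCount = 300L
    val nullCount = 0L

    assert(endpoints.length == ndvs.length + 1)

    case class Bin(lo: Double, hi: Double, ndv: Long)
    val bins = ndvs.zipWithIndex.map { case (ndv, i) => Bin(endpoints(i), endpoints(i + 1), ndv) }

    val nonNullRows = rowCount - nullCount
    val binHeight = nonNullRows.toDouble / ndvs.length   // 100.0 rows per equi-height bin
    // bins: Bin(1.0,10.0,5), Bin(10.0,20.0,7), Bin(20.0,30.0,6); height = 100.0
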
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
index e56f8105fc9a7..e11dbd201004d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command
import org.apache.hadoop.conf.Configuration
-import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
@@ -45,15 +44,7 @@ trait DataWritingCommand extends Command {
// Output columns of the analyzed input query plan
def outputColumns: Seq[Attribute]
- lazy val metrics: Map[String, SQLMetric] = {
- val sparkContext = SparkContext.getActive.get
- Map(
- "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"),
- "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"),
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
- "numParts" -> SQLMetrics.createMetric(sparkContext, "number of dynamic part")
- )
- }
+ lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = {
val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 306f43dc4214a..f6ef433f2ce15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -21,7 +21,9 @@ import java.net.URI
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -136,12 +138,11 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo
case class CreateDataSourceTableAsSelectCommand(
table: CatalogTable,
mode: SaveMode,
- query: LogicalPlan)
- extends RunnableCommand {
-
- override protected def innerChildren: Seq[LogicalPlan] = Seq(query)
+ query: LogicalPlan,
+ outputColumns: Seq[Attribute])
+ extends DataWritingCommand {
- override def run(sparkSession: SparkSession): Seq[Row] = {
+ override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
assert(table.tableType != CatalogTableType.VIEW)
assert(table.provider.isDefined)
@@ -163,24 +164,25 @@ case class CreateDataSourceTableAsSelectCommand(
}
saveDataIntoTable(
- sparkSession, table, table.storage.locationUri, query, SaveMode.Append, tableExists = true)
+ sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true)
} else {
assert(table.schema.isEmpty)
-
+ sparkSession.sessionState.catalog.validateTableLocation(table)
val tableLocation = if (table.tableType == CatalogTableType.MANAGED) {
Some(sessionState.catalog.defaultTablePath(table.identifier))
} else {
table.storage.locationUri
}
val result = saveDataIntoTable(
- sparkSession, table, tableLocation, query, SaveMode.Overwrite, tableExists = false)
+ sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false)
val newTable = table.copy(
storage = table.storage.copy(locationUri = tableLocation),
// We will use the schema of resolved.relation as the schema of the table (instead of
// the schema of df). It is important since the nullability may be changed by the relation
// provider (for example, see org.apache.spark.sql.parquet.DefaultSource).
schema = result.schema)
- sessionState.catalog.createTable(newTable, ignoreIfExists = false)
+ // Table location is already validated. No need to check it again during table creation.
+ sessionState.catalog.createTable(newTable, ignoreIfExists = false, validateLocation = false)
result match {
case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty &&
@@ -198,10 +200,10 @@ case class CreateDataSourceTableAsSelectCommand(
session: SparkSession,
table: CatalogTable,
tableLocation: Option[URI],
- data: LogicalPlan,
+ physicalPlan: SparkPlan,
mode: SaveMode,
tableExists: Boolean): BaseRelation = {
- // Create the relation based on the input logical plan: `data`.
+ // Create the relation based on the input logical plan: `query`.
val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_))
val dataSource = DataSource(
session,
@@ -212,7 +214,7 @@ case class CreateDataSourceTableAsSelectCommand(
catalogTable = if (tableExists) Some(table) else None)
try {
- dataSource.writeAndRead(mode, query)
+ dataSource.writeAndRead(mode, query, outputColumns, physicalPlan)
} catch {
case ex: AnalysisException =>
logError(s"Failed to write to table ${table.identifier.unquotedString}", ex)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 0142f17ce62e2..bf4d96fa18d0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -314,8 +314,8 @@ case class AlterTableChangeColumnCommand(
val resolver = sparkSession.sessionState.conf.resolver
DDLUtils.verifyAlterTableType(catalog, table, isView = false)
- // Find the origin column from schema by column name.
- val originColumn = findColumnByName(table.schema, columnName, resolver)
+ // Find the origin column from dataSchema by column name.
+ val originColumn = findColumnByName(table.dataSchema, columnName, resolver)
// Throw an AnalysisException if the column name/dataType is changed.
if (!columnEqual(originColumn, newColumn, resolver)) {
throw new AnalysisException(
@@ -324,7 +324,7 @@ case class AlterTableChangeColumnCommand(
s"'${newColumn.name}' with type '${newColumn.dataType}'")
}
- val newSchema = table.schema.fields.map { field =>
+ val newDataSchema = table.dataSchema.fields.map { field =>
if (field.name == originColumn.name) {
// Create a new column from the origin column with the new comment.
addComment(field, newColumn.getComment)
@@ -332,8 +332,7 @@ case class AlterTableChangeColumnCommand(
field
}
}
- val newTable = table.copy(schema = StructType(newSchema))
- catalog.alterTable(newTable)
+ catalog.alterTableDataSchema(tableName, StructType(newDataSchema))
Seq.empty[Row]
}
@@ -345,7 +344,8 @@ case class AlterTableChangeColumnCommand(
schema.fields.collectFirst {
case field if resolver(field.name, name) => field
}.getOrElse(throw new AnalysisException(
- s"Invalid column reference '$name', table schema is '${schema}'"))
+ s"Can't find column `$name` given table data columns " +
+ s"${schema.fieldNames.mkString("[`", "`, `", "`]")}"))
}
// Add the comment to a column, if comment is empty, return the original column.
@@ -610,10 +610,10 @@ case class AlterTableRecoverPartitionsCommand(
val root = new Path(table.location)
logInfo(s"Recover all the partitions in $root")
- val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration)
+ val hadoopConf = spark.sessionState.newHadoopConf()
+ val fs = root.getFileSystem(hadoopConf)
val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt
- val hadoopConf = spark.sparkContext.hadoopConfiguration
val pathFilter = getPathFilter(hadoopConf)
val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8)
@@ -697,7 +697,7 @@ case class AlterTableRecoverPartitionsCommand(
pathFilter: PathFilter,
threshold: Int): GenMap[String, PartitionStatistics] = {
if (partitionSpecsAndLocs.length > threshold) {
- val hadoopConf = spark.sparkContext.hadoopConfiguration
+ val hadoopConf = spark.sessionState.newHadoopConf()
val serializableConfiguration = new SerializableConfiguration(hadoopConf)
val serializedPaths = partitionSpecsAndLocs.map(_._2.toString).toArray
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index e400975f19708..44749190c79eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -695,10 +695,11 @@ case class DescribeColumnCommand(
// Show column stats when EXTENDED or FORMATTED is specified.
buffer += Row("min", cs.flatMap(_.min.map(_.toString)).getOrElse("NULL"))
buffer += Row("max", cs.flatMap(_.max.map(_.toString)).getOrElse("NULL"))
- buffer += Row("num_nulls", cs.map(_.nullCount.toString).getOrElse("NULL"))
- buffer += Row("distinct_count", cs.map(_.distinctCount.toString).getOrElse("NULL"))
- buffer += Row("avg_col_len", cs.map(_.avgLen.toString).getOrElse("NULL"))
- buffer += Row("max_col_len", cs.map(_.maxLen.toString).getOrElse("NULL"))
+ buffer += Row("num_nulls", cs.flatMap(_.nullCount.map(_.toString)).getOrElse("NULL"))
+ buffer += Row("distinct_count",
+ cs.flatMap(_.distinctCount.map(_.toString)).getOrElse("NULL"))
+ buffer += Row("avg_col_len", cs.flatMap(_.avgLen.map(_.toString)).getOrElse("NULL"))
+ buffer += Row("max_col_len", cs.flatMap(_.maxLen.map(_.toString)).getOrElse("NULL"))
val histDesc = for {
c <- cs
hist <- c.histogram
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
index 9dbbe9946ee99..ba7d2b7cbdb1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
@@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration
/**
- * Simple metrics collected during an instance of [[FileFormatWriter.ExecuteWriteTask]].
+ * Simple metrics collected during an instance of [[FileFormatDataWriter]].
* These were first introduced in https://github.com/apache/spark/pull/18159 (SPARK-20703).
*/
case class BasicWriteTaskStats(
@@ -153,12 +153,29 @@ class BasicWriteJobStatsTracker(
totalNumOutput += summary.numRows
}
- metrics("numFiles").add(numFiles)
- metrics("numOutputBytes").add(totalNumBytes)
- metrics("numOutputRows").add(totalNumOutput)
- metrics("numParts").add(numPartitions)
+ metrics(BasicWriteJobStatsTracker.NUM_FILES_KEY).add(numFiles)
+ metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_BYTES_KEY).add(totalNumBytes)
+ metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_ROWS_KEY).add(totalNumOutput)
+ metrics(BasicWriteJobStatsTracker.NUM_PARTS_KEY).add(numPartitions)
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList)
}
}
+
+object BasicWriteJobStatsTracker {
+ private val NUM_FILES_KEY = "numFiles"
+ private val NUM_OUTPUT_BYTES_KEY = "numOutputBytes"
+ private val NUM_OUTPUT_ROWS_KEY = "numOutputRows"
+ private val NUM_PARTS_KEY = "numParts"
+
+ def metrics: Map[String, SQLMetric] = {
+ val sparkContext = SparkContext.getActive.get
+ Map(
+ NUM_FILES_KEY -> SQLMetrics.createMetric(sparkContext, "number of written files"),
+ NUM_OUTPUT_BYTES_KEY -> SQLMetrics.createMetric(sparkContext, "bytes of written output"),
+ NUM_OUTPUT_ROWS_KEY -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+ NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part")
+ )
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
index ea4fe9c8ade5f..a776fc3e7021d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.execution.datasources
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+
object BucketingUtils {
// The file name of bucketed data should have 3 parts:
// 1. some other information in the head of file name
@@ -35,5 +38,16 @@ object BucketingUtils {
case other => None
}
+ // Given bucketColumn, numBuckets and value, returns the corresponding bucketId
+ def getBucketIdFromValue(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
+ val mutableInternalRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
+ mutableInternalRow.update(0, value)
+
+ val bucketIdGenerator = UnsafeProjection.create(
+ HashPartitioning(Seq(bucketColumn), numBuckets).partitionIdExpression :: Nil,
+ bucketColumn :: Nil)
+ bucketIdGenerator(mutableInternalRow).getInt(0)
+ }
+
def bucketIdToString(id: Int): String = f"_$id%05d"
}
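
A hedged usage sketch for the new helper: given an int bucketing column, ask which of N buckets a literal value lands in. The attribute below is constructed only for the example.

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.execution.datasources.BucketingUtils
    import org.apache.spark.sql.types.IntegerType

    // Which of 8 buckets does the value 42 of an int column map to?
    val bucketCol = AttributeReference("id", IntegerType)()
    val bucketId = BucketingUtils.getBucketIdFromValue(bucketCol, numBuckets = 8, value = 42)
    val fileSuffix = BucketingUtils.bucketIdToString(bucketId)   // e.g. "_00003"
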
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
index 4046396d0e614..a66a07673e25f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
@@ -85,7 +85,7 @@ class CatalogFileIndex(
sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec, Option(timeNs))
} else {
new InMemoryFileIndex(
- sparkSession, rootPaths, table.storage.properties, partitionSchema = None)
+ sparkSession, rootPaths, table.storage.properties, userSpecifiedSchema = None)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 25e1210504273..f16d824201e77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -23,7 +23,6 @@ import scala.collection.JavaConverters._
import scala.language.{existentials, implicitConversions}
import scala.util.{Failure, Success, Try}
-import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.deploy.SparkHadoopUtil
@@ -31,18 +30,21 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils}
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.{CalendarIntervalType, StructType}
+import org.apache.spark.sql.types.{CalendarIntervalType, StructField, StructType}
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.util.Utils
@@ -100,24 +102,6 @@ case class DataSource(
bucket.sortColumnNames, "in the sort definition", equality)
}
- /**
- * In the read path, only managed tables by Hive provide the partition columns properly when
- * initializing this class. All other file based data sources will try to infer the partitioning,
- * and then cast the inferred types to user specified dataTypes if the partition columns exist
- * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or
- * inconsistent data types as reported in SPARK-21463.
- * @param fileIndex A FileIndex that will perform partition inference
- * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema`
- */
- private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = {
- val resolved = fileIndex.partitionSchema.map { partitionField =>
- // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
- userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
- partitionField)
- }
- StructType(resolved)
- }
-
/**
* Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer
* it. In the read path, only managed tables by Hive provide the partition columns properly when
@@ -137,31 +121,26 @@ case class DataSource(
* be any further inference in any triggers.
*
* @param format the file format object for this DataSource
- * @param fileStatusCache the shared cache for file statuses to speed up listing
+ * @param fileIndex optional [[InMemoryFileIndex]] for getting partition schema and file list
* @return A pair of the data schema (excluding partition columns) and the schema of the partition
* columns.
*/
private def getOrInferFileFormatSchema(
format: FileFormat,
- fileStatusCache: FileStatusCache = NoopCache): (StructType, StructType) = {
- // the operations below are expensive therefore try not to do them if we don't need to, e.g.,
+ fileIndex: Option[InMemoryFileIndex] = None): (StructType, StructType) = {
+ // The operations below are expensive, so try not to do them if we don't need to, e.g.,
// in streaming mode, we have already inferred and registered partition columns, we will
// never have to materialize the lazy val below
- lazy val tempFileIndex = {
- val allPaths = caseInsensitiveOptions.get("path") ++ paths
- val hadoopConf = sparkSession.sessionState.newHadoopConf()
- val globbedPaths = allPaths.toSeq.flatMap { path =>
- val hdfsPath = new Path(path)
- val fs = hdfsPath.getFileSystem(hadoopConf)
- val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
- SparkHadoopUtil.get.globPathIfNecessary(fs, qualified)
- }.toArray
- new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache)
+ lazy val tempFileIndex = fileIndex.getOrElse {
+ val globbedPaths =
+ checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false)
+ createInMemoryFileIndex(globbedPaths)
}
+
val partitionSchema = if (partitionColumns.isEmpty) {
// Try to infer partitioning, because no DataSource in the read path provides the partitioning
// columns properly unless it is a Hive DataSource
- combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex)
+ tempFileIndex.partitionSchema
} else {
// maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred
// partitioning
@@ -353,13 +332,7 @@ case class DataSource(
caseInsensitiveOptions.get("path").toSeq ++ paths,
sparkSession.sessionState.newHadoopConf()) =>
val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head)
- val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None)
- val fileCatalog = if (userSpecifiedSchema.nonEmpty) {
- val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog)
- new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema))
- } else {
- tempFileCatalog
- }
+ val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, userSpecifiedSchema)
val dataSchema = userSpecifiedSchema.orElse {
format.inferSchema(
sparkSession,
@@ -381,24 +354,23 @@ case class DataSource(
// This is a non-streaming file based datasource.
case (format: FileFormat, _) =>
- val allPaths = caseInsensitiveOptions.get("path") ++ paths
- val hadoopConf = sparkSession.sessionState.newHadoopConf()
- val globbedPaths = allPaths.flatMap(
- DataSource.checkAndGlobPathIfNecessary(hadoopConf, _, checkFilesExist)).toArray
-
- val fileStatusCache = FileStatusCache.getOrCreate(sparkSession)
- val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache)
-
- val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions &&
- catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) {
+ val globbedPaths =
+ checkAndGlobPathIfNecessary(checkEmptyGlobPath = true, checkFilesExist = checkFilesExist)
+ val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions &&
+ catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog &&
+ catalogTable.get.partitionColumnNames.nonEmpty
+ val (fileCatalog, dataSchema, partitionSchema) = if (useCatalogFileIndex) {
val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes
- new CatalogFileIndex(
+ val index = new CatalogFileIndex(
sparkSession,
catalogTable.get,
catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize))
+ (index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema)
} else {
- new InMemoryFileIndex(
- sparkSession, globbedPaths, options, Some(partitionSchema), fileStatusCache)
+ val index = createInMemoryFileIndex(globbedPaths)
+ val (resultDataSchema, resultPartitionSchema) =
+ getOrInferFileFormatSchema(format, Some(index))
+ (index, resultDataSchema, resultPartitionSchema)
}
HadoopFsRelation(
@@ -435,10 +407,11 @@ case class DataSource(
}
/**
- * Writes the given [[LogicalPlan]] out in this [[FileFormat]].
+ * Creates a command node to write the given [[LogicalPlan]] out to the given [[FileFormat]].
+ * The returned command is unresolved and needs to be analyzed.
*/
private def planForWritingFileFormat(
- format: FileFormat, mode: SaveMode, data: LogicalPlan): LogicalPlan = {
+ format: FileFormat, mode: SaveMode, data: LogicalPlan): InsertIntoHadoopFsRelationCommand = {
// Don't glob path for the write path. The contracts here are:
// 1. Only one output path can be specified on the write path;
// 2. Output path must be a legal HDFS style file system path;
@@ -482,9 +455,24 @@ case class DataSource(
/**
* Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]] for
* the following reading.
+ *
+ * @param mode The save mode for this writing.
+ * @param data The input query plan that produces the data to be written. Note that this plan
+ * is analyzed and optimized.
+ * @param outputColumns The original output columns of the input query plan. The optimizer may not
+ * preserve the output column's names' case, so we need this parameter
+ * instead of `data.output`.
+ * @param physicalPlan The physical plan of the input query plan. We should run the writing
+ * command with this physical plan instead of creating a new physical plan,
+ * so that the metrics can be correctly linked to the given physical plan and
+ * shown in the web UI.
*/
- def writeAndRead(mode: SaveMode, data: LogicalPlan): BaseRelation = {
- if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
+ def writeAndRead(
+ mode: SaveMode,
+ data: LogicalPlan,
+ outputColumns: Seq[Attribute],
+ physicalPlan: SparkPlan): BaseRelation = {
+ if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
throw new AnalysisException("Cannot save interval data type into external storage.")
}
@@ -493,9 +481,23 @@ case class DataSource(
dataSource.createRelation(
sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data))
case format: FileFormat =>
- sparkSession.sessionState.executePlan(planForWritingFileFormat(format, mode, data)).toRdd
+ val cmd = planForWritingFileFormat(format, mode, data)
+ val resolvedPartCols = cmd.partitionColumns.map { col =>
+ // The partition columns created in `planForWritingFileFormat` should always be
+ // `UnresolvedAttribute` with a single name part.
+ assert(col.isInstanceOf[UnresolvedAttribute])
+ val unresolved = col.asInstanceOf[UnresolvedAttribute]
+ assert(unresolved.nameParts.length == 1)
+ val name = unresolved.nameParts.head
+ outputColumns.find(a => equality(a.name, name)).getOrElse {
+ throw new AnalysisException(
+ s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]")
+ }
+ }
+ val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns)
+ resolved.run(sparkSession, physicalPlan)
// Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
- copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
+ copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
case _ =>
sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
}
@@ -513,11 +515,46 @@ case class DataSource(
case dataSource: CreatableRelationProvider =>
SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode)
case format: FileFormat =>
+ DataSource.validateSchema(data.schema)
planForWritingFileFormat(format, mode, data)
case _ =>
sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
}
}
+
+ /** Returns an [[InMemoryFileIndex]] that can be used to get partition schema and file list. */
+ private def createInMemoryFileIndex(globbedPaths: Seq[Path]): InMemoryFileIndex = {
+ val fileStatusCache = FileStatusCache.getOrCreate(sparkSession)
+ new InMemoryFileIndex(
+ sparkSession, globbedPaths, options, userSpecifiedSchema, fileStatusCache)
+ }
+
+ /**
+ * Checks the configured paths, globbing them if necessary, and returns all the resolved paths.
+ */
+ private def checkAndGlobPathIfNecessary(
+ checkEmptyGlobPath: Boolean,
+ checkFilesExist: Boolean): Seq[Path] = {
+ val allPaths = caseInsensitiveOptions.get("path") ++ paths
+ val hadoopConf = sparkSession.sessionState.newHadoopConf()
+ allPaths.flatMap { path =>
+ val hdfsPath = new Path(path)
+ val fs = hdfsPath.getFileSystem(hadoopConf)
+ val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified)
+
+ if (checkEmptyGlobPath && globPath.isEmpty) {
+ throw new AnalysisException(s"Path does not exist: $qualified")
+ }
+
+ // Sufficient to check head of the globPath seq for non-glob scenario
+ // Don't need to check once again if files exist in streaming mode
+ if (checkFilesExist && !fs.exists(globPath.head)) {
+ throw new AnalysisException(s"Path does not exist: ${globPath.head}")
+ }
+ globPath
+ }.toSeq
+ }
}
object DataSource extends Logging {
@@ -531,6 +568,8 @@ object DataSource extends Logging {
val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat"
val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
val nativeOrc = classOf[OrcFileFormat].getCanonicalName
+ val socket = classOf[TextSocketSourceProvider].getCanonicalName
+ val rate = classOf[RateStreamProvider].getCanonicalName
Map(
"org.apache.spark.sql.jdbc" -> jdbc,
@@ -551,7 +590,9 @@ object DataSource extends Logging {
"org.apache.spark.sql.execution.datasources.orc" -> nativeOrc,
"org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
"org.apache.spark.ml.source.libsvm" -> libsvm,
- "com.databricks.spark.csv" -> csv
+ "com.databricks.spark.csv" -> csv,
+ "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
+ "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
)
}
@@ -662,26 +703,25 @@ object DataSource extends Logging {
}
/**
- * If `path` is a file pattern, return all the files that match it. Otherwise, return itself.
- * If `checkFilesExist` is `true`, also check the file existence.
+ * Called before writing into a FileFormat based data source to make sure the
+ * supplied schema is not empty.
+ * @param schema the data schema to be validated before the write
*/
- private def checkAndGlobPathIfNecessary(
- hadoopConf: Configuration,
- path: String,
- checkFilesExist: Boolean): Seq[Path] = {
- val hdfsPath = new Path(path)
- val fs = hdfsPath.getFileSystem(hadoopConf)
- val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
- val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified)
-
- if (globPath.isEmpty) {
- throw new AnalysisException(s"Path does not exist: $qualified")
+ private def validateSchema(schema: StructType): Unit = {
+ def hasEmptySchema(schema: StructType): Boolean = {
+ schema.size == 0 || schema.find {
+ case StructField(_, b: StructType, _, _) => hasEmptySchema(b)
+ case _ => false
+ }.isDefined
}
- // Sufficient to check head of the globPath seq for non-glob scenario
- // Don't need to check once again if files exist in streaming mode
- if (checkFilesExist && !fs.exists(globPath.head)) {
- throw new AnalysisException(s"Path does not exist: ${globPath.head}")
+
+
+ if (hasEmptySchema(schema)) {
+ throw new AnalysisException(
+ s"""
+ |Datasource does not support writing empty or nested empty schemas.
+ |Please make sure the data schema has at least one or more column(s).
+ """.stripMargin)
}
- globPath
}
}
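
For the new `validateSchema` check, here is a small self-contained sketch of the rule it enforces: a write is rejected when the schema is empty or contains a nested empty struct. `hasEmptySchema` below is a local re-implementation for illustration, not the private Spark method itself.

```scala
import org.apache.spark.sql.types._

object EmptySchemaCheckSketch {
  def hasEmptySchema(schema: StructType): Boolean =
    schema.isEmpty || schema.exists {
      case StructField(_, nested: StructType, _, _) => hasEmptySchema(nested)
      case _ => false
    }

  def main(args: Array[String]): Unit = {
    val writable = StructType(Seq(StructField("a", IntegerType)))
    val emptyTopLevel = StructType(Nil)
    val nestedEmpty = StructType(Seq(StructField("s", StructType(Nil))))

    println(hasEmptySchema(writable))      // false: allowed
    println(hasEmptySchema(emptyTopLevel)) // true: rejected with an AnalysisException
    println(hasEmptySchema(nestedEmpty))   // true: rejected with an AnalysisException
  }
}
```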
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index d94c5bbccdd84..7b129435c45db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
case CreateTable(tableDesc, mode, Some(query))
if query.resolved && DDLUtils.isDatasourceTable(tableDesc) =>
DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema))
- CreateDataSourceTableAsSelectCommand(tableDesc, mode, query)
+ CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output)
case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _),
parts, query, overwrite, false) if parts.isEmpty =>
@@ -312,18 +312,6 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
case _ => Nil
}
- // Get the bucket ID based on the bucketing values.
- // Restriction: Bucket pruning works iff the bucketing column has one and only one column.
- def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
- val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
- mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null)
- val bucketIdGeneration = UnsafeProjection.create(
- HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
- bucketColumn :: Nil)
-
- bucketIdGeneration(mutableRow).getInt(0)
- }
-
// Based on Public API.
private def pruneFilterProject(
relation: LogicalRelation,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
new file mode 100644
index 0000000000000..6499328e89ce7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -0,0 +1,313 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * Abstract class for writing out data in a single Spark task.
+ * Exceptions thrown by the implementation of this trait will automatically trigger task aborts.
+ */
+abstract class FileFormatDataWriter(
+ description: WriteJobDescription,
+ taskAttemptContext: TaskAttemptContext,
+ committer: FileCommitProtocol) {
+ /**
+ * Max number of files a single task writes out due to file size. In most cases the number of
+ * files written should be very small. This is just a safeguard to protect against some
+ * really bad settings, e.g. maxRecordsPerFile = 1.
+ */
+ protected val MAX_FILE_COUNTER: Int = 1000 * 1000
+ protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]()
+ protected var currentWriter: OutputWriter = _
+
+ /** Trackers for computing various statistics on the data as it's being written out. */
+ protected val statsTrackers: Seq[WriteTaskStatsTracker] =
+ description.statsTrackers.map(_.newTaskInstance())
+
+ protected def releaseResources(): Unit = {
+ if (currentWriter != null) {
+ try {
+ currentWriter.close()
+ } finally {
+ currentWriter = null
+ }
+ }
+ }
+
+ /** Writes a record */
+ def write(record: InternalRow): Unit
+
+ /**
+ * Returns a summary of the relevant information, which
+ * includes the list of partition strings written out. The list of partitions is sent back
+ * to the driver and used to update the catalog. Other information is sent back to the
+ * driver too and used to, e.g., update the metrics in the UI.
+ */
+ def commit(): WriteTaskResult = {
+ releaseResources()
+ val summary = ExecutedWriteSummary(
+ updatedPartitions = updatedPartitions.toSet,
+ stats = statsTrackers.map(_.getFinalStats()))
+ WriteTaskResult(committer.commitTask(taskAttemptContext), summary)
+ }
+
+ def abort(): Unit = {
+ try {
+ releaseResources()
+ } finally {
+ committer.abortTask(taskAttemptContext)
+ }
+ }
+}
+
+/** FileFormatDataWriter for empty partitions */
+class EmptyDirectoryDataWriter(
+ description: WriteJobDescription,
+ taskAttemptContext: TaskAttemptContext,
+ committer: FileCommitProtocol
+) extends FileFormatDataWriter(description, taskAttemptContext, committer) {
+ override def write(record: InternalRow): Unit = {}
+}
+
+/** Writes data to a single directory (used for non-dynamic-partition writes). */
+class SingleDirectoryDataWriter(
+ description: WriteJobDescription,
+ taskAttemptContext: TaskAttemptContext,
+ committer: FileCommitProtocol)
+ extends FileFormatDataWriter(description, taskAttemptContext, committer) {
+ private var fileCounter: Int = _
+ private var recordsInFile: Long = _
+ // Initialize currentWriter and statsTrackers
+ newOutputWriter()
+
+ private def newOutputWriter(): Unit = {
+ recordsInFile = 0
+ releaseResources()
+
+ val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
+ val currentPath = committer.newTaskTempFile(
+ taskAttemptContext,
+ None,
+ f"-c$fileCounter%03d" + ext)
+
+ currentWriter = description.outputWriterFactory.newInstance(
+ path = currentPath,
+ dataSchema = description.dataColumns.toStructType,
+ context = taskAttemptContext)
+
+ statsTrackers.foreach(_.newFile(currentPath))
+ }
+
+ override def write(record: InternalRow): Unit = {
+ if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
+ fileCounter += 1
+ assert(fileCounter < MAX_FILE_COUNTER,
+ s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+ newOutputWriter()
+ }
+
+ currentWriter.write(record)
+ statsTrackers.foreach(_.newRow(record))
+ recordsInFile += 1
+ }
+}
+
+/**
+ * Writes data using dynamic partition writes, meaning this single function can write to
+ * multiple directories (partitions) or files (bucketing).
+ */
+class DynamicPartitionDataWriter(
+ description: WriteJobDescription,
+ taskAttemptContext: TaskAttemptContext,
+ committer: FileCommitProtocol)
+ extends FileFormatDataWriter(description, taskAttemptContext, committer) {
+
+ /** Flag saying whether or not the data to be written out is partitioned. */
+ private val isPartitioned = description.partitionColumns.nonEmpty
+
+ /** Flag saying whether or not the data to be written out is bucketed. */
+ private val isBucketed = description.bucketIdExpression.isDefined
+
+ assert(isPartitioned || isBucketed,
+ s"""DynamicPartitionWriteTask should be used for writing out data that's either
+ |partitioned or bucketed. In this case neither is true.
+ |WriteJobDescription: $description
+ """.stripMargin)
+
+ private var fileCounter: Int = _
+ private var recordsInFile: Long = _
+ private var currentPartionValues: Option[UnsafeRow] = None
+ private var currentBucketId: Option[Int] = None
+
+ /** Extracts the partition values out of an input row. */
+ private lazy val getPartitionValues: InternalRow => UnsafeRow = {
+ val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns)
+ row => proj(row)
+ }
+
+ /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */
+ private lazy val partitionPathExpression: Expression = Concat(
+ description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+ val partitionName = ScalaUDF(
+ ExternalCatalogUtils.getPartitionPathString _,
+ StringType,
+ Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId))))
+ if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName)
+ })
+
+ /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns
+ * the partition string. */
+ private lazy val getPartitionPath: InternalRow => String = {
+ val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns)
+ row => proj(row).getString(0)
+ }
+
+ /** Given an input row, returns the corresponding `bucketId` */
+ private lazy val getBucketId: InternalRow => Int = {
+ val proj =
+ UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns)
+ row => proj(row).getInt(0)
+ }
+
+ /** Returns the data columns to be written given an input row */
+ private val getOutputRow =
+ UnsafeProjection.create(description.dataColumns, description.allColumns)
+
+ /**
+ * Opens a new OutputWriter given a partition key and/or a bucket id.
+ * If bucket id is specified, we will append it to the end of the file name, but before the
+ * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
+ *
+ * @param partitionValues the partition which all tuples being written by this `OutputWriter`
+ * belong to
+ * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to
+ */
+ private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = {
+ recordsInFile = 0
+ releaseResources()
+
+ val partDir = partitionValues.map(getPartitionPath(_))
+ partDir.foreach(updatedPartitions.add)
+
+ val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
+
+ // This must be in a form that matches our bucketing format. See BucketingUtils.
+ val ext = f"$bucketIdStr.c$fileCounter%03d" +
+ description.outputWriterFactory.getFileExtension(taskAttemptContext)
+
+ val customPath = partDir.flatMap { dir =>
+ description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+ }
+ val currentPath = if (customPath.isDefined) {
+ committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
+ } else {
+ committer.newTaskTempFile(taskAttemptContext, partDir, ext)
+ }
+
+ currentWriter = description.outputWriterFactory.newInstance(
+ path = currentPath,
+ dataSchema = description.dataColumns.toStructType,
+ context = taskAttemptContext)
+
+ statsTrackers.foreach(_.newFile(currentPath))
+ }
+
+ override def write(record: InternalRow): Unit = {
+ val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None
+ val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None
+
+ if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
+ // See a new partition or bucket - write to a new partition dir (or a new bucket file).
+ if (isPartitioned && currentPartionValues != nextPartitionValues) {
+ currentPartionValues = Some(nextPartitionValues.get.copy())
+ statsTrackers.foreach(_.newPartition(currentPartionValues.get))
+ }
+ if (isBucketed) {
+ currentBucketId = nextBucketId
+ statsTrackers.foreach(_.newBucket(currentBucketId.get))
+ }
+
+ fileCounter = 0
+ newOutputWriter(currentPartionValues, currentBucketId)
+ } else if (description.maxRecordsPerFile > 0 &&
+ recordsInFile >= description.maxRecordsPerFile) {
+ // Exceeded the threshold in terms of the number of records per file.
+ // Create a new file by increasing the file counter.
+ fileCounter += 1
+ assert(fileCounter < MAX_FILE_COUNTER,
+ s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+ newOutputWriter(currentPartionValues, currentBucketId)
+ }
+ val outputRow = getOutputRow(record)
+ currentWriter.write(outputRow)
+ statsTrackers.foreach(_.newRow(outputRow))
+ recordsInFile += 1
+ }
+}
+
+/** A shared job description for all the write tasks. */
+class WriteJobDescription(
+ val uuid: String, // prevent collision between different (appending) write jobs
+ val serializableHadoopConf: SerializableConfiguration,
+ val outputWriterFactory: OutputWriterFactory,
+ val allColumns: Seq[Attribute],
+ val dataColumns: Seq[Attribute],
+ val partitionColumns: Seq[Attribute],
+ val bucketIdExpression: Option[Expression],
+ val path: String,
+ val customPartitionLocations: Map[TablePartitionSpec, String],
+ val maxRecordsPerFile: Long,
+ val timeZoneId: String,
+ val statsTrackers: Seq[WriteJobStatsTracker])
+ extends Serializable {
+
+ assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
+ s"""
+ |All columns: ${allColumns.mkString(", ")}
+ |Partition columns: ${partitionColumns.mkString(", ")}
+ |Data columns: ${dataColumns.mkString(", ")}
+ """.stripMargin)
+}
+
+/** The result of a successful write task. */
+case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary)
+
+/**
+ * Wrapper class for the metrics of writing data out.
+ *
+ * @param updatedPartitions the partitions updated during writing data out. Only valid
+ * for dynamic partition.
+ * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had.
+ */
+case class ExecutedWriteSummary(
+ updatedPartitions: Set[String],
+ stats: Seq[WriteTaskStats])
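
To make the naming rules in `DynamicPartitionDataWriter.newOutputWriter` concrete, here is an illustrative, self-contained sketch of how the partition directory and the bucket/file-counter suffix are assembled. The real code escapes partition values via `ExternalCatalogUtils.getPartitionPathString`, which this sketch skips, and the example values are invented.

```scala
object WriterNamingSketch {
  // Mirrors BucketingUtils.bucketIdToString: underscore prefix, zero-padded to 5 digits.
  def bucketIdToString(id: Int): String = f"_$id%05d"

  def main(args: Array[String]): Unit = {
    // Partition directory segments: col1=val/col2=val/... (unescaped here for brevity).
    val partDir = Seq("year" -> "2018", "month" -> "03")
      .map { case (k, v) => s"$k=$v" }
      .mkString("/")                                   // year=2018/month=03

    val bucketId = Some(2)
    val fileCounter = 0
    val bucketIdStr = bucketId.map(bucketIdToString).getOrElse("")

    // Bucket id and file counter go before the format extension, as in newOutputWriter.
    val ext = f"$bucketIdStr.c$fileCounter%03d" + ".snappy.parquet"
    println(s"$partDir/part-00000-<uuid>$ext")         // .../_00002.c000.snappy.parquet
  }
}
```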
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 1d80a69bc5a1d..52da8356ab835 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources
import java.util.{Date, UUID}
-import scala.collection.mutable
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
@@ -30,62 +28,25 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils}
-import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
-import org.apache.spark.sql.types.StringType
import org.apache.spark.util.{SerializableConfiguration, Utils}
/** A helper object for writing FileFormat data out to a location. */
object FileFormatWriter extends Logging {
-
- /**
- * Max number of files a single task writes out due to file size. In most cases the number of
- * files written should be very small. This is just a safe guard to protect some really bad
- * settings, e.g. maxRecordsPerFile = 1.
- */
- private val MAX_FILE_COUNTER = 1000 * 1000
-
/** Describes how output files should be placed in the filesystem. */
case class OutputSpec(
- outputPath: String,
- customPartitionLocations: Map[TablePartitionSpec, String],
- outputColumns: Seq[Attribute])
-
- /** A shared job description for all the write tasks. */
- private class WriteJobDescription(
- val uuid: String, // prevent collision between different (appending) write jobs
- val serializableHadoopConf: SerializableConfiguration,
- val outputWriterFactory: OutputWriterFactory,
- val allColumns: Seq[Attribute],
- val dataColumns: Seq[Attribute],
- val partitionColumns: Seq[Attribute],
- val bucketIdExpression: Option[Expression],
- val path: String,
- val customPartitionLocations: Map[TablePartitionSpec, String],
- val maxRecordsPerFile: Long,
- val timeZoneId: String,
- val statsTrackers: Seq[WriteJobStatsTracker])
- extends Serializable {
-
- assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
- s"""
- |All columns: ${allColumns.mkString(", ")}
- |Partition columns: ${partitionColumns.mkString(", ")}
- |Data columns: ${dataColumns.mkString(", ")}
- """.stripMargin)
- }
-
- /** The result of a successful write task. */
- private case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary)
+ outputPath: String,
+ customPartitionLocations: Map[TablePartitionSpec, String],
+ outputColumns: Seq[Attribute])
/**
* Basic work flow of this command is:
@@ -190,9 +151,18 @@ object FileFormatWriter extends Logging {
global = false,
child = plan).execute()
}
- val ret = new Array[WriteTaskResult](rdd.partitions.length)
+
+ // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
+ // partition rdd to make sure we at least set up one write task to write the metadata.
+ val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) {
+ sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1)
+ } else {
+ rdd
+ }
+
+ val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length)
sparkSession.sparkContext.runJob(
- rdd,
+ rddWithNonEmptyPartitions,
(taskContext: TaskContext, iter: Iterator[InternalRow]) => {
executeTask(
description = description,
@@ -202,7 +172,7 @@ object FileFormatWriter extends Logging {
committer,
iterator = iter)
},
- 0 until rdd.partitions.length,
+ rddWithNonEmptyPartitions.partitions.indices,
(index, res: WriteTaskResult) => {
committer.onTaskCommit(res.commitMsg)
ret(index) = res
@@ -253,30 +223,27 @@ object FileFormatWriter extends Logging {
committer.setupTask(taskAttemptContext)
- val writeTask =
+ val dataWriter =
if (sparkPartitionId != 0 && !iterator.hasNext) {
// In case of empty job, leave first partition to save meta for file format like parquet.
- new EmptyDirectoryWriteTask(description)
+ new EmptyDirectoryDataWriter(description, taskAttemptContext, committer)
} else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
- new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
+ new SingleDirectoryDataWriter(description, taskAttemptContext, committer)
} else {
- new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
+ new DynamicPartitionDataWriter(description, taskAttemptContext, committer)
}
try {
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Execute the task to write rows out and commit the task.
- val summary = writeTask.execute(iterator)
- writeTask.releaseResources()
- WriteTaskResult(committer.commitTask(taskAttemptContext), summary)
- })(catchBlock = {
- // If there is an error, release resource and then abort the task
- try {
- writeTask.releaseResources()
- } finally {
- committer.abortTask(taskAttemptContext)
- logError(s"Job $jobId aborted.")
+ while (iterator.hasNext) {
+ dataWriter.write(iterator.next())
}
+ dataWriter.commit()
+ })(catchBlock = {
+ // If there is an error, abort the task
+ dataWriter.abort()
+ logError(s"Job $jobId aborted.")
})
} catch {
case e: FetchFailedException =>
@@ -293,7 +260,7 @@ object FileFormatWriter extends Logging {
private def processStats(
statsTrackers: Seq[WriteJobStatsTracker],
statsPerTask: Seq[Seq[WriteTaskStats]])
- : Unit = {
+ : Unit = {
val numStatsTrackers = statsTrackers.length
assert(statsPerTask.forall(_.length == numStatsTrackers),
@@ -312,281 +279,4 @@ object FileFormatWriter extends Logging {
case (statsTracker, stats) => statsTracker.processStats(stats)
}
}
-
- /**
- * A simple trait for writing out data in a single Spark task, without any concerns about how
- * to commit or abort tasks. Exceptions thrown by the implementation of this trait will
- * automatically trigger task aborts.
- */
- private trait ExecuteWriteTask {
-
- /**
- * Writes data out to files, and then returns the summary of relative information which
- * includes the list of partition strings written out. The list of partitions is sent back
- * to the driver and used to update the catalog. Other information will be sent back to the
- * driver too and used to e.g. update the metrics in UI.
- */
- def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary
- def releaseResources(): Unit
- }
-
- /** ExecuteWriteTask for empty partitions */
- private class EmptyDirectoryWriteTask(description: WriteJobDescription)
- extends ExecuteWriteTask {
-
- val statsTrackers: Seq[WriteTaskStatsTracker] =
- description.statsTrackers.map(_.newTaskInstance())
-
- override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
- ExecutedWriteSummary(
- updatedPartitions = Set.empty,
- stats = statsTrackers.map(_.getFinalStats()))
- }
-
- override def releaseResources(): Unit = {}
- }
-
- /** Writes data to a single directory (used for non-dynamic-partition writes). */
- private class SingleDirectoryWriteTask(
- description: WriteJobDescription,
- taskAttemptContext: TaskAttemptContext,
- committer: FileCommitProtocol) extends ExecuteWriteTask {
-
- private[this] var currentWriter: OutputWriter = _
-
- val statsTrackers: Seq[WriteTaskStatsTracker] =
- description.statsTrackers.map(_.newTaskInstance())
-
- private def newOutputWriter(fileCounter: Int): Unit = {
- val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext)
- val currentPath = committer.newTaskTempFile(
- taskAttemptContext,
- None,
- f"-c$fileCounter%03d" + ext)
-
- currentWriter = description.outputWriterFactory.newInstance(
- path = currentPath,
- dataSchema = description.dataColumns.toStructType,
- context = taskAttemptContext)
-
- statsTrackers.map(_.newFile(currentPath))
- }
-
- override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
- var fileCounter = 0
- var recordsInFile: Long = 0L
- newOutputWriter(fileCounter)
-
- while (iter.hasNext) {
- if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) {
- fileCounter += 1
- assert(fileCounter < MAX_FILE_COUNTER,
- s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
-
- recordsInFile = 0
- releaseResources()
- newOutputWriter(fileCounter)
- }
-
- val internalRow = iter.next()
- currentWriter.write(internalRow)
- statsTrackers.foreach(_.newRow(internalRow))
- recordsInFile += 1
- }
- releaseResources()
- ExecutedWriteSummary(
- updatedPartitions = Set.empty,
- stats = statsTrackers.map(_.getFinalStats()))
- }
-
- override def releaseResources(): Unit = {
- if (currentWriter != null) {
- try {
- currentWriter.close()
- } finally {
- currentWriter = null
- }
- }
- }
- }
-
- /**
- * Writes data to using dynamic partition writes, meaning this single function can write to
- * multiple directories (partitions) or files (bucketing).
- */
- private class DynamicPartitionWriteTask(
- desc: WriteJobDescription,
- taskAttemptContext: TaskAttemptContext,
- committer: FileCommitProtocol) extends ExecuteWriteTask {
-
- /** Flag saying whether or not the data to be written out is partitioned. */
- val isPartitioned = desc.partitionColumns.nonEmpty
-
- /** Flag saying whether or not the data to be written out is bucketed. */
- val isBucketed = desc.bucketIdExpression.isDefined
-
- assert(isPartitioned || isBucketed,
- s"""DynamicPartitionWriteTask should be used for writing out data that's either
- |partitioned or bucketed. In this case neither is true.
- |WriteJobDescription: ${desc}
- """.stripMargin)
-
- // currentWriter is initialized whenever we see a new key (partitionValues + BucketId)
- private var currentWriter: OutputWriter = _
-
- /** Trackers for computing various statistics on the data as it's being written out. */
- private val statsTrackers: Seq[WriteTaskStatsTracker] =
- desc.statsTrackers.map(_.newTaskInstance())
-
- /** Extracts the partition values out of an input row. */
- private lazy val getPartitionValues: InternalRow => UnsafeRow = {
- val proj = UnsafeProjection.create(desc.partitionColumns, desc.allColumns)
- row => proj(row)
- }
-
- /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */
- private lazy val partitionPathExpression: Expression = Concat(
- desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
- val partitionName = ScalaUDF(
- ExternalCatalogUtils.getPartitionPathString _,
- StringType,
- Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId))))
- if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName)
- })
-
- /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns
- * the partition string. */
- private lazy val getPartitionPath: InternalRow => String = {
- val proj = UnsafeProjection.create(Seq(partitionPathExpression), desc.partitionColumns)
- row => proj(row).getString(0)
- }
-
- /** Given an input row, returns the corresponding `bucketId` */
- private lazy val getBucketId: InternalRow => Int = {
- val proj = UnsafeProjection.create(desc.bucketIdExpression.toSeq, desc.allColumns)
- row => proj(row).getInt(0)
- }
-
- /** Returns the data columns to be written given an input row */
- private val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)
-
- /**
- * Opens a new OutputWriter given a partition key and/or a bucket id.
- * If bucket id is specified, we will append it to the end of the file name, but before the
- * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
- *
- * @param partitionValues the partition which all tuples being written by this `OutputWriter`
- * belong to
- * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to
- * @param fileCounter the number of files that have been written in the past for this specific
- * partition. This is used to limit the max number of records written for a
- * single file. The value should start from 0.
- * @param updatedPartitions the set of updated partition paths, we should add the new partition
- * path of this writer to it.
- */
- private def newOutputWriter(
- partitionValues: Option[InternalRow],
- bucketId: Option[Int],
- fileCounter: Int,
- updatedPartitions: mutable.Set[String]): Unit = {
-
- val partDir = partitionValues.map(getPartitionPath(_))
- partDir.foreach(updatedPartitions.add)
-
- val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
-
- // This must be in a form that matches our bucketing format. See BucketingUtils.
- val ext = f"$bucketIdStr.c$fileCounter%03d" +
- desc.outputWriterFactory.getFileExtension(taskAttemptContext)
-
- val customPath = partDir.flatMap { dir =>
- desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
- }
- val currentPath = if (customPath.isDefined) {
- committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
- } else {
- committer.newTaskTempFile(taskAttemptContext, partDir, ext)
- }
-
- currentWriter = desc.outputWriterFactory.newInstance(
- path = currentPath,
- dataSchema = desc.dataColumns.toStructType,
- context = taskAttemptContext)
-
- statsTrackers.foreach(_.newFile(currentPath))
- }
-
- override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = {
- // If anything below fails, we should abort the task.
- var recordsInFile: Long = 0L
- var fileCounter = 0
- val updatedPartitions = mutable.Set[String]()
- var currentPartionValues: Option[UnsafeRow] = None
- var currentBucketId: Option[Int] = None
-
- for (row <- iter) {
- val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None
- val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None
-
- if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
- // See a new partition or bucket - write to a new partition dir (or a new bucket file).
- if (isPartitioned && currentPartionValues != nextPartitionValues) {
- currentPartionValues = Some(nextPartitionValues.get.copy())
- statsTrackers.foreach(_.newPartition(currentPartionValues.get))
- }
- if (isBucketed) {
- currentBucketId = nextBucketId
- statsTrackers.foreach(_.newBucket(currentBucketId.get))
- }
-
- recordsInFile = 0
- fileCounter = 0
-
- releaseResources()
- newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
- } else if (desc.maxRecordsPerFile > 0 &&
- recordsInFile >= desc.maxRecordsPerFile) {
- // Exceeded the threshold in terms of the number of records per file.
- // Create a new file by increasing the file counter.
- recordsInFile = 0
- fileCounter += 1
- assert(fileCounter < MAX_FILE_COUNTER,
- s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
-
- releaseResources()
- newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
- }
- val outputRow = getOutputRow(row)
- currentWriter.write(outputRow)
- statsTrackers.foreach(_.newRow(outputRow))
- recordsInFile += 1
- }
- releaseResources()
-
- ExecutedWriteSummary(
- updatedPartitions = updatedPartitions.toSet,
- stats = statsTrackers.map(_.getFinalStats()))
- }
-
- override def releaseResources(): Unit = {
- if (currentWriter != null) {
- try {
- currentWriter.close()
- } finally {
- currentWriter = null
- }
- }
- }
- }
}
-
-/**
- * Wrapper class for the metrics of writing data out.
- *
- * @param updatedPartitions the partitions updated during writing data out. Only valid
- * for dynamic partition.
- * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had.
- */
-case class ExecutedWriteSummary(
- updatedPartitions: Set[String],
- stats: Seq[WriteTaskStats])
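
The SPARK-23271 change above can be summarised with a short sketch (assumed local-mode session, invented names): if the physical plan yields an RDD with zero partitions, a one-partition empty RDD is used instead, so a single write task still runs and produces the output metadata.

```scala
import org.apache.spark.sql.SparkSession

object EmptyPlanWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("empty-plan-write").getOrCreate()
    val sc = spark.sparkContext

    val rdd = sc.emptyRDD[String]          // stands in for the query's zero-partition output
    val rddWithNonEmptyPartitions =
      if (rdd.partitions.length == 0) sc.parallelize(Seq.empty[String], 1) else rdd

    // Exactly one task is scheduled even though there is no data to write.
    println(rddWithNonEmptyPartitions.partitions.length) // 1

    spark.stop()
  }
}
```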
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 835ce98462477..28c36b6020d33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -21,11 +21,14 @@ import java.io.{FileNotFoundException, IOException}
import scala.collection.mutable
+import org.apache.parquet.io.ParquetDecodingException
+
import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{InputFileBlockHolder, RDD}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.NextIterator
@@ -179,7 +182,23 @@ class FileScanRDD(
currentIterator = readCurrentFile()
}
- hasNext
+ try {
+ hasNext
+ } catch {
+ case e: SchemaColumnConvertNotSupportedException =>
+ val message = "Parquet column cannot be converted in " +
+ s"file ${currentFile.filePath}. Column: ${e.getColumn}, " +
+ s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}"
+ throw new QueryExecutionException(message, e)
+ case e: ParquetDecodingException =>
+ if (e.getMessage.contains("Can not read value at")) {
+ val message = "Encounter error while reading parquet files. " +
+ "One possible cause: Parquet column cannot be converted in the " +
+ "corresponding files. Details: "
+ throw new QueryExecutionException(message, e)
+ }
+ throw e
+ }
} else {
currentFile = null
InputFileBlockHolder.unset()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 16b22717b8d92..fe27b78bf3360 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.FileSourceScanExec
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
+import org.apache.spark.util.collection.BitSet
/**
* A strategy for planning scans over collections of files that might be partitioned or bucketed
@@ -50,6 +51,91 @@ import org.apache.spark.sql.execution.SparkPlan
* and add it. Proceed to the next file.
*/
object FileSourceStrategy extends Strategy with Logging {
+
+ // should prune buckets iff num buckets is greater than 1 and there is only one bucket column
+ private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = {
+ bucketSpec match {
+ case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1
+ case None => false
+ }
+ }
+
+ private def getExpressionBuckets(
+ expr: Expression,
+ bucketColumnName: String,
+ numBuckets: Int): BitSet = {
+
+ def getBucketNumber(attr: Attribute, v: Any): Int = {
+ BucketingUtils.getBucketIdFromValue(attr, numBuckets, v)
+ }
+
+ def getBucketSetFromIterable(attr: Attribute, iter: Iterable[Any]): BitSet = {
+ val matchedBuckets = new BitSet(numBuckets)
+ iter
+ .map(v => getBucketNumber(attr, v))
+ .foreach(bucketNum => matchedBuckets.set(bucketNum))
+ matchedBuckets
+ }
+
+ def getBucketSetFromValue(attr: Attribute, v: Any): BitSet = {
+ val matchedBuckets = new BitSet(numBuckets)
+ matchedBuckets.set(getBucketNumber(attr, v))
+ matchedBuckets
+ }
+
+ expr match {
+ case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
+ getBucketSetFromValue(a, v)
+ case expressions.In(a: Attribute, list)
+ if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
+ getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow)))
+ case expressions.InSet(a: Attribute, hset)
+ if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
+ getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow)))
+ case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
+ getBucketSetFromValue(a, null)
+ case expressions.And(left, right) =>
+ getExpressionBuckets(left, bucketColumnName, numBuckets) &
+ getExpressionBuckets(right, bucketColumnName, numBuckets)
+ case expressions.Or(left, right) =>
+ getExpressionBuckets(left, bucketColumnName, numBuckets) |
+ getExpressionBuckets(right, bucketColumnName, numBuckets)
+ case _ =>
+ val matchedBuckets = new BitSet(numBuckets)
+ matchedBuckets.setUntil(numBuckets)
+ matchedBuckets
+ }
+ }
+
+ private def genBucketSet(
+ normalizedFilters: Seq[Expression],
+ bucketSpec: BucketSpec): Option[BitSet] = {
+ if (normalizedFilters.isEmpty) {
+ return None
+ }
+
+ val bucketColumnName = bucketSpec.bucketColumnNames.head
+ val numBuckets = bucketSpec.numBuckets
+
+ val normalizedFiltersAndExpr = normalizedFilters
+ .reduce(expressions.And)
+ val matchedBuckets = getExpressionBuckets(normalizedFiltersAndExpr, bucketColumnName,
+ numBuckets)
+
+ val numBucketsSelected = matchedBuckets.cardinality()
+
+ logInfo {
+ s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets."
+ }
+
+ // None means all the buckets need to be scanned
+ if (numBucketsSelected == numBuckets) {
+ None
+ } else {
+ Some(matchedBuckets)
+ }
+ }
+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projects, filters,
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
@@ -76,9 +162,19 @@ object FileSourceStrategy extends Strategy with Logging {
fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver)
val partitionSet = AttributeSet(partitionColumns)
val partitionKeyFilters =
- ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet)))
+ ExpressionSet(normalizedFilters
+ .filterNot(SubqueryExpression.hasSubquery(_))
+ .filter(_.references.subsetOf(partitionSet)))
+
logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")
+ val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec
+ val bucketSet = if (shouldPruneBuckets(bucketSpec)) {
+ genBucketSet(normalizedFilters, bucketSpec.get)
+ } else {
+ None
+ }
+
val dataColumns =
l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver)
@@ -108,6 +204,7 @@ object FileSourceStrategy extends Strategy with Logging {
outputAttributes,
outputSchema,
partitionKeyFilters.toSeq,
+ bucketSet,
dataFilters,
table.map(_.identifier))
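
The set algebra in `getExpressionBuckets`/`genBucketSet` can be pictured with a toy model (not the Spark classes themselves): each predicate yields a set of candidate buckets, `And` intersects, `Or` unions, and anything the pruner does not understand keeps all buckets. `scala.collection.immutable.BitSet` stands in for Spark's internal `org.apache.spark.util.collection.BitSet`.

```scala
import scala.collection.immutable.BitSet

object BucketPruningSketch {
  sealed trait Pred
  case class EqBucket(bucket: Int) extends Pred  // e.g. bucketCol = literal hashing to this bucket
  case class And(l: Pred, r: Pred) extends Pred
  case class Or(l: Pred, r: Pred) extends Pred
  case object Unknown extends Pred               // any predicate bucket pruning cannot exploit

  def buckets(p: Pred, numBuckets: Int): BitSet = p match {
    case EqBucket(b) => BitSet(b)
    case And(l, r)   => buckets(l, numBuckets) & buckets(r, numBuckets)
    case Or(l, r)    => buckets(l, numBuckets) | buckets(r, numBuckets)
    case Unknown     => BitSet(0 until numBuckets: _*)
  }

  def main(args: Array[String]): Unit = {
    val numBuckets = 8
    // (bucketCol = v1 OR bucketCol = v2) AND <non-prunable predicate>
    val matched = buckets(And(Or(EqBucket(1), EqBucket(5)), Unknown), numBuckets)
    println(matched)                       // BitSet(1, 5): 6 of 8 buckets are pruned
  }
}
```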
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
index 83cf26c63a175..00a78f7343c59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
@@ -30,9 +30,22 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
/**
* An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines
* in that file.
+ *
+ * @param file A part (i.e. "block") of a single file that should be read line by line.
+ * @param lineSeparator A line separator that should be used for each line. If the value is `None`,
+ * it covers `\r`, `\r\n` and `\n`.
+ * @param conf Hadoop configuration
+ *
+ * @note The behavior when `lineSeparator` is `None` (covering `\r`, `\r\n` and `\n`) is defined
+ * by [[LineRecordReader]], not within Spark.
*/
class HadoopFileLinesReader(
- file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable {
+ file: PartitionedFile,
+ lineSeparator: Option[Array[Byte]],
+ conf: Configuration) extends Iterator[Text] with Closeable {
+
+ def this(file: PartitionedFile, conf: Configuration) = this(file, None, conf)
+
private val iterator = {
val fileSplit = new FileSplit(
new Path(new URI(file.filePath)),
@@ -42,7 +55,13 @@ class HadoopFileLinesReader(
Array.empty)
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
- val reader = new LineRecordReader()
+
+ val reader = lineSeparator match {
+ case Some(sep) => new LineRecordReader(sep)
+ // If the line separator is `None`, it covers `\r`, `\r\n` and `\n`.
+ case _ => new LineRecordReader()
+ }
+
reader.initialize(fileSplit, hadoopAttemptContext)
new RecordReaderIterator(reader)
}
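
From the user-facing side, the new `lineSeparator` parameter is what lets a file source feed a custom record delimiter into Hadoop's `LineRecordReader(sep)`. The sketch below assumes a build where the text source exposes this as a `lineSep` read option; the option name and file path are assumptions, not confirmed by this diff.

```scala
import org.apache.spark.sql.SparkSession

object LineSepSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("line-sep-sketch").getOrCreate()

    // Records terminated by '|' instead of the default \r, \r\n or \n handling.
    val df = spark.read
      .option("lineSep", "|")              // assumed option; plumbed down to LineRecordReader(sep)
      .text("/tmp/records-with-pipe-separator.txt")

    df.show(truncate = false)
    spark.stop()
  }
}
```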
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
index 6b34638529770..b2f73b7f8d1fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
@@ -67,6 +67,9 @@ case class HadoopFsRelation(
}
}
+ // When data and partition schemas have overlapping columns, the output
+ // schema respects the order of the data schema for the overlapping columns, and it
+ // respects the data types of the partition schema.
val schema: StructType = {
StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++
partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f))))
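
The comment added above describes a merge rule that can be restated with a small local sketch (re-implemented here with a plain case-sensitive name match; the overlapping column `day` is invented): an overlapping column keeps its position from the data schema but takes its type from the partition schema.

```scala
import org.apache.spark.sql.types._

object OverlappedSchemaSketch {
  def mergedSchema(dataSchema: StructType, partitionSchema: StructType): StructType = {
    val overlapped = partitionSchema.map(f => f.name -> f).toMap
    StructType(
      dataSchema.map(f => overlapped.getOrElse(f.name, f)) ++
        partitionSchema.filterNot(f => dataSchema.exists(_.name == f.name)))
  }

  def main(args: Array[String]): Unit = {
    val dataSchema = StructType(Seq(
      StructField("id", LongType),
      StructField("day", StringType)))     // overlaps with a partition column
    val partitionSchema = StructType(Seq(StructField("day", IntegerType)))

    // "day" stays second (data-schema order) but becomes IntegerType (partition-schema type).
    mergedSchema(dataSchema, partitionSchema).printTreeString()
  }
}
```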
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
index 318ada0ceefc5..9d9f8bd5bb58e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
@@ -41,17 +41,17 @@ import org.apache.spark.util.SerializableConfiguration
* @param rootPathsSpecified the list of root table paths to scan (some of which might be
* filtered out later)
* @param parameters as set of options to control discovery
- * @param partitionSchema an optional partition schema that will be use to provide types for the
- * discovered partitions
+ * @param userSpecifiedSchema an optional user-specified schema that will be used to provide
+ * types for the discovered partitions
*/
class InMemoryFileIndex(
sparkSession: SparkSession,
rootPathsSpecified: Seq[Path],
parameters: Map[String, String],
- partitionSchema: Option[StructType],
+ userSpecifiedSchema: Option[StructType],
fileStatusCache: FileStatusCache = NoopCache)
extends PartitioningAwareFileIndex(
- sparkSession, parameters, partitionSchema, fileStatusCache) {
+ sparkSession, parameters, userSpecifiedSchema, fileStatusCache) {
// Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir)
// or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain
@@ -294,9 +294,12 @@ object InMemoryFileIndex extends Logging {
if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles
}
- allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map {
+ val missingFiles = mutable.ArrayBuffer.empty[String]
+ val filteredLeafStatuses = allLeafStatuses.filterNot(
+ status => shouldFilterOut(status.getPath.getName))
+ val resolvedLeafStatuses = filteredLeafStatuses.flatMap {
case f: LocatedFileStatus =>
- f
+ Some(f)
// NOTE:
//
@@ -311,14 +314,27 @@ object InMemoryFileIndex extends Logging {
// The other constructor of LocatedFileStatus will call FileStatus.getPermission(),
// which is very slow on some file system (RawLocalFileSystem, which is launch a
// subprocess and parse the stdout).
- val locations = fs.getFileBlockLocations(f, 0, f.getLen)
- val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize,
- f.getModificationTime, 0, null, null, null, null, f.getPath, locations)
- if (f.isSymlink) {
- lfs.setSymlink(f.getSymlink)
+ try {
+ val locations = fs.getFileBlockLocations(f, 0, f.getLen)
+ val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize,
+ f.getModificationTime, 0, null, null, null, null, f.getPath, locations)
+ if (f.isSymlink) {
+ lfs.setSymlink(f.getSymlink)
+ }
+ Some(lfs)
+ } catch {
+ case _: FileNotFoundException =>
+ missingFiles += f.getPath.toString
+ None
}
- lfs
}
+
+ if (missingFiles.nonEmpty) {
+ logWarning(
+ s"the following files were missing during file scan:\n ${missingFiles.mkString("\n ")}")
+ }
+
+ resolvedLeafStatuses
}
/** Checks if we should filter out this path name. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index 6b6f6388d54e8..cc8af7b92c454 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -34,13 +34,13 @@ import org.apache.spark.sql.types.{StringType, StructType}
* It provides the necessary methods to parse partition data based on a set of files.
*
* @param parameters as set of options to control partition discovery
- * @param userPartitionSchema an optional partition schema that will be use to provide types for
- * the discovered partitions
+ * @param userSpecifiedSchema an optional user specified schema that will be used to provide
+ * types for the discovered partitions
*/
abstract class PartitioningAwareFileIndex(
sparkSession: SparkSession,
parameters: Map[String, String],
- userPartitionSchema: Option[StructType],
+ userSpecifiedSchema: Option[StructType],
fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging {
import PartitioningAwareFileIndex.BASE_PATH_PARAM
@@ -126,35 +126,32 @@ abstract class PartitioningAwareFileIndex(
val caseInsensitiveOptions = CaseInsensitiveMap(parameters)
val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
-
- userPartitionSchema match {
+ val inferredPartitionSpec = PartitioningUtils.parsePartitions(
+ leafDirs,
+ typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
+ basePaths = basePaths,
+ timeZoneId = timeZoneId)
+ userSpecifiedSchema match {
case Some(userProvidedSchema) if userProvidedSchema.nonEmpty =>
- val spec = PartitioningUtils.parsePartitions(
- leafDirs,
- typeInference = false,
- basePaths = basePaths,
- timeZoneId = timeZoneId)
+ val userPartitionSchema =
+ combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec)
- // Without auto inference, all of value in the `row` should be null or in StringType,
// we need to cast into the data type that user specified.
def castPartitionValuesToUserSchema(row: InternalRow) = {
InternalRow((0 until row.numFields).map { i =>
+ val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType
Cast(
- Literal.create(row.getUTF8String(i), StringType),
- userProvidedSchema.fields(i).dataType,
+ Literal.create(row.get(i, dt), dt),
+ userPartitionSchema.fields(i).dataType,
Option(timeZoneId)).eval()
}: _*)
}
- PartitionSpec(userProvidedSchema, spec.partitions.map { part =>
+ PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part =>
part.copy(values = castPartitionValuesToUserSchema(part.values))
})
case _ =>
- PartitioningUtils.parsePartitions(
- leafDirs,
- typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
- basePaths = basePaths,
- timeZoneId = timeZoneId)
+ inferredPartitionSpec
}
}
@@ -236,6 +233,25 @@ abstract class PartitioningAwareFileIndex(
val name = path.getName
!((name.startsWith("_") && !name.contains("=")) || name.startsWith("."))
}
+
+ /**
+ * In the read path, only tables managed by Hive provide the partition columns properly when
+ * initializing this class. All other file-based data sources will try to infer the partitioning,
+ * and then cast the inferred types to user specified dataTypes if the partition columns exist
+ * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or
+ * inconsistent data types as reported in SPARK-21463.
+ * @param spec A partition inference result
+ * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema`
+ */
+ private def combineInferredAndUserSpecifiedPartitionSchema(spec: PartitionSpec): StructType = {
+ val equality = sparkSession.sessionState.conf.resolver
+ val resolved = spec.partitionColumns.map { partitionField =>
+ // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
+ userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
+ partitionField)
+ }
+ StructType(resolved)
+ }
}
object PartitioningAwareFileIndex {
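A hedged sketch of the effect of combineInferredAndUserSpecifiedPartitionSchema from the caller's perspective: a partition column that also appears in the user-specified schema keeps the user's type, while the values are still discovered by inference and cast accordingly. The directory layout below is hypothetical.

    val df = spark.read
      .schema("id LONG, part DATE")          // user schema includes the partition column
      .parquet("/data/tbl")                  // e.g. /data/tbl/part=2018-01-01/part-00000.parquet
    df.schema("part").dataType               // DateType, taken from the user schema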
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 472bf82d3604d..f9a24806953e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -407,6 +407,34 @@ object PartitioningUtils {
Literal(bigDecimal)
}
+ val dateTry = Try {
+ // try to parse the date; if no exception occurs, this is a candidate to be resolved as
+ // DateType
+ DateTimeUtils.getThreadLocalDateFormat.parse(raw)
+ // SPARK-23436: Casting the string to date may still return null if a bad Date is provided.
+ // This can happen since DateFormat.parse may not use the entire text of the given string,
+ // so it can succeed even if there are extra characters after the date.
+ // We need to check that we can cast the raw string since we later can use Cast to get
+ // the partition values with the right DataType (see
+ // org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning)
+ val dateValue = Cast(Literal(raw), DateType).eval()
+ // Disallow DateType if the cast returned null
+ require(dateValue != null)
+ Literal.create(dateValue, DateType)
+ }
+
+ val timestampTry = Try {
+ val unescapedRaw = unescapePathName(raw)
+ // try to parse the timestamp; if no exception occurs, this is a candidate to be resolved as
+ // TimestampType
+ DateTimeUtils.getThreadLocalTimestampFormat(timeZone).parse(unescapedRaw)
+ // SPARK-23436: see comment for date
+ val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(timeZone.getID)).eval()
+ // Disallow TimestampType if the cast returned null
+ require(timestampValue != null)
+ Literal.create(timestampValue, TimestampType)
+ }
+
if (typeInference) {
// First tries integral types
Try(Literal.create(Integer.parseInt(raw), IntegerType))
@@ -415,16 +443,8 @@ object PartitioningUtils {
// Then falls back to fractional types
.orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
// Then falls back to date/timestamp types
- .orElse(Try(
- Literal.create(
- DateTimeUtils.getThreadLocalTimestampFormat(timeZone)
- .parse(unescapePathName(raw)).getTime * 1000L,
- TimestampType)))
- .orElse(Try(
- Literal.create(
- DateTimeUtils.millisToDays(
- DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime),
- DateType)))
+ .orElse(timestampTry)
+ .orElse(dateTry)
// Then falls back to string
.getOrElse {
if (raw == DEFAULT_PARTITION_NAME) {
@@ -466,7 +486,8 @@ object PartitioningUtils {
val equality = columnNameEquality(caseSensitive)
StructType(partitionColumns.map { col =>
schema.find(f => equality(f.name, col)).getOrElse {
- throw new AnalysisException(s"Partition column $col not found in schema $schema")
+ val schemaCatalog = schema.catalogString
+ throw new AnalysisException(s"Partition column `$col` not found in schema $schemaCatalog")
}
}).asNullable
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 3b830accb83f0..16b2367bfdd5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -55,7 +55,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
partitionSchema, sparkSession.sessionState.analyzer.resolver)
val partitionSet = AttributeSet(partitionColumns)
val partitionKeyFilters =
- ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet)))
+ ExpressionSet(normalizedFilters
+ .filterNot(SubqueryExpression.hasSubquery(_))
+ .filter(_.references.subsetOf(partitionSet)))
if (partitionKeyFilters.nonEmpty) {
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index 568e953a5db66..00b1b5dedb593 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -17,13 +17,12 @@
package org.apache.spark.sql.execution.datasources
-import org.apache.spark.SparkEnv
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.CreatableRelationProvider
-import org.apache.spark.util.Utils
/**
* Saves the results of `query` in to a data source.
@@ -50,7 +49,7 @@ case class SaveIntoDataSourceCommand(
}
override def simpleString: String = {
- val redacted = Utils.redact(SparkEnv.get.conf, options.toSeq).toMap
+ val redacted = SQLConf.get.redactOptions(options)
s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}"
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 4870d75fc5f08..82322df407521 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.TaskContext
import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
+import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
@@ -50,7 +51,10 @@ abstract class CSVDataSource extends Serializable {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
- schema: StructType): Iterator[InternalRow]
+ requiredSchema: StructType,
+ // Actual schema of data in the csv file
+ dataSchema: StructType,
+ caseSensitive: Boolean): Iterator[InternalRow]
/**
* Infers the schema from `inputPaths` files.
@@ -110,7 +114,7 @@ abstract class CSVDataSource extends Serializable {
}
}
-object CSVDataSource {
+object CSVDataSource extends Logging {
def apply(options: CSVOptions): CSVDataSource = {
if (options.multiLine) {
MultiLineCSVDataSource
@@ -118,6 +122,84 @@ object CSVDataSource {
TextInputCSVDataSource
}
}
+
+ /**
+ * Checks that column names in a CSV header and field names in the schema are the same
+ * by taking into account case sensitivity.
+ *
+ * @param schema - provided (or inferred) schema to which CSV must conform.
+ * @param columnNames - names of CSV columns that must be checked against the schema.
+ * @param fileName - name of the CSV file that is currently checked. It is used in error messages.
+ * @param enforceSchema - if it is `true`, column names are ignored, otherwise the CSV column
+ * names are checked for conformance to the schema. If the column names
+ * don't conform to the schema, an exception is thrown.
+ * @param caseSensitive - if it is set to `false`, comparison of column names and schema field
+ * names is not case sensitive.
+ */
+ def checkHeaderColumnNames(
+ schema: StructType,
+ columnNames: Array[String],
+ fileName: String,
+ enforceSchema: Boolean,
+ caseSensitive: Boolean): Unit = {
+ if (columnNames != null) {
+ val fieldNames = schema.map(_.name).toIndexedSeq
+ val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
+ var errorMessage: Option[String] = None
+
+ if (headerLen == schemaSize) {
+ var i = 0
+ while (errorMessage.isEmpty && i < headerLen) {
+ var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
+ if (!caseSensitive) {
+ nameInSchema = nameInSchema.toLowerCase
+ nameInHeader = nameInHeader.toLowerCase
+ }
+ if (nameInHeader != nameInSchema) {
+ errorMessage = Some(
+ s"""|CSV header does not conform to the schema.
+ | Header: ${columnNames.mkString(", ")}
+ | Schema: ${fieldNames.mkString(", ")}
+ |Expected: ${fieldNames(i)} but found: ${columnNames(i)}
+ |CSV file: $fileName""".stripMargin)
+ }
+ i += 1
+ }
+ } else {
+ errorMessage = Some(
+ s"""|Number of column in CSV header is not equal to number of fields in the schema:
+ | Header length: $headerLen, schema size: $schemaSize
+ |CSV file: $fileName""".stripMargin)
+ }
+
+ errorMessage.foreach { msg =>
+ if (enforceSchema) {
+ logWarning(msg)
+ } else {
+ throw new IllegalArgumentException(msg)
+ }
+ }
+ }
+ }
+
+ /**
+ * Checks that CSV header contains the same column names as fields names in the given schema
+ * by taking into account case sensitivity.
+ */
+ def checkHeader(
+ header: String,
+ parser: CsvParser,
+ schema: StructType,
+ fileName: String,
+ enforceSchema: Boolean,
+ caseSensitive: Boolean): Unit = {
+ checkHeaderColumnNames(
+ schema,
+ parser.parseLine(header),
+ fileName,
+ enforceSchema,
+ caseSensitive)
+ }
}
object TextInputCSVDataSource extends CSVDataSource {
@@ -127,7 +209,9 @@ object TextInputCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
- schema: StructType): Iterator[InternalRow] = {
+ requiredSchema: StructType,
+ dataSchema: StructType,
+ caseSensitive: Boolean): Iterator[InternalRow] = {
val lines = {
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
@@ -136,8 +220,24 @@ object TextInputCSVDataSource extends CSVDataSource {
}
}
- val shouldDropHeader = parser.options.headerFlag && file.start == 0
- UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema)
+ val hasHeader = parser.options.headerFlag && file.start == 0
+ if (hasHeader) {
+ // Checking that column names in the header are matched to field names of the schema.
+ // The header will be removed from lines.
+ // Note: if there are only comments in the first block, the header would probably
+ // not be extracted.
+ CSVUtils.extractHeader(lines, parser.options).foreach { header =>
+ CSVDataSource.checkHeader(
+ header,
+ parser.tokenizer,
+ dataSchema,
+ file.filePath,
+ parser.options.enforceSchema,
+ caseSensitive)
+ }
+ }
+
+ UnivocityParser.parseIterator(lines, parser, requiredSchema)
}
override def infer(
@@ -161,7 +261,8 @@ object TextInputCSVDataSource extends CSVDataSource {
val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
- val tokenRDD = csv.rdd.mapPartitions { iter =>
+ val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
+ val tokenRDD = sampled.rdd.mapPartitions { iter =>
val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
val linesWithoutHeader =
CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
@@ -184,7 +285,8 @@ object TextInputCSVDataSource extends CSVDataSource {
DataSource.apply(
sparkSession,
paths = paths,
- className = classOf[TextFileFormat].getName
+ className = classOf[TextFileFormat].getName,
+ options = options.parameters
).resolveRelation(checkFilesExist = false))
.select("value").as[String](Encoders.STRING)
} else {
@@ -204,12 +306,24 @@ object MultiLineCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
- schema: StructType): Iterator[InternalRow] = {
+ requiredSchema: StructType,
+ dataSchema: StructType,
+ caseSensitive: Boolean): Iterator[InternalRow] = {
+ def checkHeader(header: Array[String]): Unit = {
+ CSVDataSource.checkHeaderColumnNames(
+ dataSchema,
+ header,
+ file.filePath,
+ parser.options.enforceSchema,
+ caseSensitive)
+ }
+
UnivocityParser.parseStream(
CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))),
parser.options.headerFlag,
parser,
- schema)
+ requiredSchema,
+ checkHeader)
}
override def infer(
@@ -235,7 +349,8 @@ object MultiLineCSVDataSource extends CSVDataSource {
parsedOptions.headerFlag,
new CsvParser(parsedOptions.asParserSettings))
}
- CSVInferSchema.infer(tokenRDD, header, parsedOptions)
+ val sampled = CSVUtils.sample(tokenRDD, parsedOptions)
+ CSVInferSchema.infer(sampled, header, parsedOptions)
case None =>
// If the first row could not be read, just return the empty schema.
StructType(Nil)
@@ -248,7 +363,8 @@ object MultiLineCSVDataSource extends CSVDataSource {
options: CSVOptions): RDD[PortableDataStream] = {
val paths = inputPaths.map(_.getPath)
val name = paths.mkString(",")
- val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+ val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions(
+ options.parameters))
FileInputFormat.setInputPaths(job, paths: _*)
val conf = job.getConfiguration
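A hedged usage sketch of the new header validation: with enforceSchema=false the CSV header of each file is checked against the provided (or inferred) schema and a mismatch fails the read, while the default (true) only logs a warning. The path is hypothetical.

    val people = spark.read
      .option("header", "true")
      .option("enforceSchema", "false")
      .schema("id INT, name STRING")
      .csv("/data/people.csv")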
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index e20977a4ec79f..b90275de9f40a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -41,8 +41,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
sparkSession: SparkSession,
options: Map[String, String],
path: Path): Boolean = {
- val parsedOptions =
- new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
+ val parsedOptions = new CSVOptions(
+ options,
+ columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
+ sparkSession.sessionState.conf.sessionLocalTimeZone)
val csvDataSource = CSVDataSource(parsedOptions)
csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path)
}
@@ -51,8 +53,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
- val parsedOptions =
- new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
+ val parsedOptions = new CSVOptions(
+ options,
+ columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
+ sparkSession.sessionState.conf.sessionLocalTimeZone)
CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions)
}
@@ -64,7 +68,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
dataSchema: StructType): OutputWriterFactory = {
CSVUtils.verifySchema(dataSchema)
val conf = job.getConfiguration
- val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone)
+ val csvOptions = new CSVOptions(
+ options,
+ columnPruning = sparkSession.sessionState.conf.csvColumnPruning,
+ sparkSession.sessionState.conf.sessionLocalTimeZone)
csvOptions.compressionCodec.foreach { codec =>
CompressionCodecs.setCodecConfiguration(conf, codec)
}
@@ -97,6 +104,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
val parsedOptions = new CSVOptions(
options,
+ sparkSession.sessionState.conf.csvColumnPruning,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
@@ -122,6 +130,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
"df.filter($\"_corrupt_record\".isNotNull).count()."
)
}
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
(file: PartitionedFile) => {
val conf = broadcastedHadoopConf.value.value
@@ -129,7 +138,13 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
parsedOptions)
- CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema)
+ CSVDataSource(parsedOptions).readFile(
+ conf,
+ file,
+ parser,
+ requiredSchema,
+ dataSchema,
+ caseSensitive)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index c16790630ce17..fab8d62da0c1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -27,17 +27,20 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util._
class CSVOptions(
- @transient private val parameters: CaseInsensitiveMap[String],
+ @transient val parameters: CaseInsensitiveMap[String],
+ val columnPruning: Boolean,
defaultTimeZoneId: String,
defaultColumnNameOfCorruptRecord: String)
extends Logging with Serializable {
def this(
parameters: Map[String, String],
+ columnPruning: Boolean,
defaultTimeZoneId: String,
defaultColumnNameOfCorruptRecord: String = "") = {
this(
CaseInsensitiveMap(parameters),
+ columnPruning,
defaultTimeZoneId,
defaultColumnNameOfCorruptRecord)
}
@@ -150,6 +153,15 @@ class CSVOptions(
val isCommentSet = this.comment != '\u0000'
+ val samplingRatio =
+ parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+ /**
+ * Forcibly apply the specified or inferred schema to datasource files.
+ * If the option is enabled, headers of CSV files will be ignored.
+ */
+ val enforceSchema = getBool("enforceSchema", default = true)
+
def asWriterSettings: CsvWriterSettings = {
val writerSettings = new CsvWriterSettings()
val format = writerSettings.getFormat
@@ -161,7 +173,7 @@ class CSVOptions(
writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
writerSettings.setNullValue(nullValue)
- writerSettings.setEmptyValue(nullValue)
+ writerSettings.setEmptyValue("\"\"")
writerSettings.setSkipEmptyLines(true)
writerSettings.setQuoteAllFields(quoteAll)
writerSettings.setQuoteEscapingEnabled(escapeQuotes)
@@ -182,6 +194,7 @@ class CSVOptions(
settings.setInputBufferSize(inputBufferSize)
settings.setMaxColumns(maxColumns)
settings.setNullValue(nullValue)
+ settings.setEmptyValue("")
settings.setMaxCharsPerColumn(maxCharsPerColumn)
settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
settings
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
index 72b053d2092ca..1012e774118e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.datasources.csv
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -67,12 +68,8 @@ object CSVUtils {
}
}
- /**
- * Drop header line so that only data can remain.
- * This is similar with `filterHeaderLine` above and currently being used in CSV reading path.
- */
- def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
- val nonEmptyLines = if (options.isCommentSet) {
+ def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+ if (options.isCommentSet) {
val commentPrefix = options.comment.toString
iter.dropWhile { line =>
line.trim.isEmpty || line.trim.startsWith(commentPrefix)
@@ -80,11 +77,19 @@ object CSVUtils {
} else {
iter.dropWhile(_.trim.isEmpty)
}
-
- if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
- iter
}
+ /**
+ * Extracts header and moves iterator forward so that only data remains in it
+ */
+ def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = {
+ val nonEmptyLines = skipComments(iter, options)
+ if (nonEmptyLines.hasNext) {
+ Some(nonEmptyLines.next())
+ } else {
+ None
+ }
+ }
/**
* Helper method that converts string representation of a character to actual character.
* It handles some Java escaped strings and throws exception if given string is longer than one
@@ -131,4 +136,29 @@ object CSVUtils {
schema.foreach(field => verifyType(field.dataType))
}
+ /**
+ * Sample CSV dataset as configured by `samplingRatio`.
+ */
+ def sample(csv: Dataset[String], options: CSVOptions): Dataset[String] = {
+ require(options.samplingRatio > 0,
+ s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+ if (options.samplingRatio > 0.99) {
+ csv
+ } else {
+ csv.sample(withReplacement = false, options.samplingRatio, 1)
+ }
+ }
+
+ /**
+ * Sample CSV RDD as configured by `samplingRatio`.
+ */
+ def sample(csv: RDD[Array[String]], options: CSVOptions): RDD[Array[String]] = {
+ require(options.samplingRatio > 0,
+ s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+ if (options.samplingRatio > 0.99) {
+ csv
+ } else {
+ csv.sample(withReplacement = false, options.samplingRatio, 1)
+ }
+ }
}
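A hedged sketch of the new samplingRatio option: only a fraction of the input is scanned during CSV schema inference, and ratios above 0.99 skip sampling entirely, mirroring the helper above. The path is hypothetical.

    val big = spark.read
      .option("header", "true")
      .option("inferSchema", "true")
      .option("samplingRatio", "0.1")
      .csv("/data/large/*.csv")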
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 7d6d7e7eef926..5f7d5696b71a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv
import java.io.InputStream
import java.math.BigDecimal
-import java.text.NumberFormat
-import java.util.Locale
import scala.util.Try
import scala.util.control.NonFatal
@@ -36,10 +34,10 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class UnivocityParser(
- schema: StructType,
+ dataSchema: StructType,
requiredSchema: StructType,
val options: CSVOptions) extends Logging {
- require(requiredSchema.toSet.subsetOf(schema.toSet),
+ require(requiredSchema.toSet.subsetOf(dataSchema.toSet),
"requiredSchema should be the subset of schema.")
def this(schema: StructType, options: CSVOptions) = this(schema, schema, options)
@@ -47,9 +45,17 @@ class UnivocityParser(
// A `ValueConverter` is responsible for converting the given value to a desired type.
private type ValueConverter = String => Any
- private val tokenizer = new CsvParser(options.asParserSettings)
+ val tokenizer = {
+ val parserSetting = options.asParserSettings
+ if (options.columnPruning && requiredSchema.length < dataSchema.length) {
+ val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f)))
+ parserSetting.selectIndexes(tokenIndexArr: _*)
+ }
+ new CsvParser(parserSetting)
+ }
+ private val schema = if (options.columnPruning) requiredSchema else dataSchema
- private val row = new GenericInternalRow(requiredSchema.length)
+ private val row = new GenericInternalRow(schema.length)
// Retrieve the raw record string.
private def getCurrentInput: UTF8String = {
@@ -75,11 +81,8 @@ class UnivocityParser(
// Each input token is placed in each output row's position by mapping these. In this case,
//
// output row - ["A", 2]
- private val valueConverters: Array[ValueConverter] =
+ private val valueConverters: Array[ValueConverter] = {
schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
-
- private val tokenIndexArr: Array[Int] = {
- requiredSchema.map(f => schema.indexOf(f)).toArray
}
/**
@@ -203,6 +206,8 @@ class UnivocityParser(
case _: BadRecordException => None
}
}
+ // For records with fewer or more tokens than the schema, tries to return partial results
+ // if possible.
throw BadRecordException(
() => getCurrentInput,
() => getPartialResult(),
@@ -210,14 +215,16 @@ class UnivocityParser(
} else {
try {
var i = 0
- while (i < requiredSchema.length) {
- val from = tokenIndexArr(i)
- row(i) = valueConverters(from).apply(tokens(from))
+ while (i < schema.length) {
+ row(i) = valueConverters(i).apply(tokens(i))
i += 1
}
row
} catch {
case NonFatal(e) =>
+ // For corrupted records with the same number of tokens as the schema,
+ // the CSV reader doesn't support partial results. All fields other than the field
+ // configured by `columnNameOfCorruptRecord` are set to `null`.
throw BadRecordException(() => getCurrentInput, () => None, e)
}
}
@@ -243,14 +250,15 @@ private[csv] object UnivocityParser {
inputStream: InputStream,
shouldDropHeader: Boolean,
parser: UnivocityParser,
- schema: StructType): Iterator[InternalRow] = {
+ schema: StructType,
+ checkHeader: Array[String] => Unit): Iterator[InternalRow] = {
val tokenizer = parser.tokenizer
val safeParser = new FailureSafeParser[Array[String]](
input => Seq(parser.convert(input)),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
- convertStream(inputStream, shouldDropHeader, tokenizer) { tokens =>
+ convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens =>
safeParser.parse(tokens)
}.flatten
}
@@ -258,11 +266,14 @@ private[csv] object UnivocityParser {
private def convertStream[T](
inputStream: InputStream,
shouldDropHeader: Boolean,
- tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] {
+ tokenizer: CsvParser,
+ checkHeader: Array[String] => Unit = _ => ())(
+ convert: Array[String] => T) = new Iterator[T] {
tokenizer.beginParsing(inputStream)
private var nextRecord = {
if (shouldDropHeader) {
- tokenizer.parseNext()
+ val firstRecord = tokenizer.parseNext()
+ checkHeader(firstRecord)
}
tokenizer.parseNext()
}
@@ -284,21 +295,11 @@ private[csv] object UnivocityParser {
*/
def parseIterator(
lines: Iterator[String],
- shouldDropHeader: Boolean,
parser: UnivocityParser,
schema: StructType): Iterator[InternalRow] = {
val options = parser.options
- val linesWithoutHeader = if (shouldDropHeader) {
- // Note that if there are only comments in the first block, the header would probably
- // be not dropped.
- CSVUtils.dropHeaderLine(lines, options)
- } else {
- lines
- }
-
- val filteredLines: Iterator[String] =
- CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options)
+ val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options)
val safeParser = new FailureSafeParser[String](
input => Seq(parser.parse(input)),
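A hedged sketch of the column-pruning behavior wired into the tokenizer above: when pruning is enabled (the default), Univocity parses only the columns required by the query. The SQL conf key below is an assumption for illustration (spark.sql.csv.parser.columnPruning.enabled).

    spark.conf.set("spark.sql.csv.parser.columnPruning.enabled", "false")   // assumed key; disables pruning
    val names = spark.read.option("header", "true").csv("/data/people.csv").select("name")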
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
index 7a6c0f9fed2f9..1723596de1db2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
@@ -32,6 +32,13 @@ import org.apache.spark.util.Utils
*/
object DriverRegistry extends Logging {
+ /**
+ * Load DriverManager first to avoid any race condition between the
+ * DriverManager static initialization block and a specific driver class's
+ * static initialization block, e.g. PhoenixDriver.
+ */
+ DriverManager.getDrivers
+
private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty
def register(className: String): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index b4e5d169066d9..a73a97c06fe5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -89,6 +89,10 @@ class JDBCOptions(
// the number of partitions
val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt)
+ // the number of seconds the driver will wait for a Statement object to execute.
+ // Zero means there is no limit.
+ val queryTimeout = parameters.getOrElse(JDBC_QUERY_TIMEOUT, "0").toInt
+
// ------------------------------------------------------------
// Optional parameters only for reading
// ------------------------------------------------------------
@@ -160,6 +164,7 @@ object JDBCOptions {
val JDBC_LOWER_BOUND = newOption("lowerBound")
val JDBC_UPPER_BOUND = newOption("upperBound")
val JDBC_NUM_PARTITIONS = newOption("numPartitions")
+ val JDBC_QUERY_TIMEOUT = newOption("queryTimeout")
val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
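A hedged usage sketch of the new queryTimeout option: the value is forwarded to java.sql.Statement.setQueryTimeout on both the read and write paths, and 0 means no limit. The connection details are hypothetical.

    val events = spark.read
      .format("jdbc")
      .option("url", "jdbc:postgresql://db:5432/test")
      .option("dbtable", "public.events")
      .option("queryTimeout", "10")
      .load()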
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 05326210f3242..0bab3689e5d0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -57,6 +57,7 @@ object JDBCRDD extends Logging {
try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
try {
+ statement.setQueryTimeout(options.queryTimeout)
val rs = statement.executeQuery()
try {
JdbcUtils.getSchema(rs, dialect, alwaysNullable = true)
@@ -281,6 +282,7 @@ private[jdbc] class JDBCRDD(
val statement = conn.prepareStatement(sql)
logInfo(s"Executing sessionInitStatement: $sql")
try {
+ statement.setQueryTimeout(options.queryTimeout)
statement.execute()
} finally {
statement.close()
@@ -298,6 +300,7 @@ private[jdbc] class JDBCRDD(
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
+ stmt.setQueryTimeout(options.queryTimeout)
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index cc506e51bd0c6..f8c5677ea0f2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -73,7 +73,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
saveTable(df, tableSchema, isCaseSensitive, options)
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
- dropTable(conn, options.table)
+ dropTable(conn, options.table, options)
createTable(conn, df, options)
saveTable(df, Some(df.schema), isCaseSensitive, options)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index e6dc2fda4eb1b..433443007cfd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -76,6 +76,7 @@ object JdbcUtils extends Logging {
Try {
val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table))
try {
+ statement.setQueryTimeout(options.queryTimeout)
statement.executeQuery()
} finally {
statement.close()
@@ -86,9 +87,10 @@ object JdbcUtils extends Logging {
/**
* Drops a table from the JDBC database.
*/
- def dropTable(conn: Connection, table: String): Unit = {
+ def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = {
val statement = conn.createStatement
try {
+ statement.setQueryTimeout(options.queryTimeout)
statement.executeUpdate(s"DROP TABLE $table")
} finally {
statement.close()
@@ -102,6 +104,7 @@ object JdbcUtils extends Logging {
val dialect = JdbcDialects.get(options.url)
val statement = conn.createStatement
try {
+ statement.setQueryTimeout(options.queryTimeout)
statement.executeUpdate(dialect.getTruncateQuery(options.table))
} finally {
statement.close()
@@ -254,6 +257,7 @@ object JdbcUtils extends Logging {
try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table))
try {
+ statement.setQueryTimeout(options.queryTimeout)
Some(getSchema(statement.executeQuery(), dialect))
} catch {
case _: SQLException => None
@@ -596,7 +600,8 @@ object JdbcUtils extends Logging {
insertStmt: String,
batchSize: Int,
dialect: JdbcDialect,
- isolationLevel: Int): Iterator[Byte] = {
+ isolationLevel: Int,
+ options: JDBCOptions): Iterator[Byte] = {
val conn = getConnection()
var committed = false
@@ -637,6 +642,9 @@ object JdbcUtils extends Logging {
try {
var rowCount = 0
+
+ stmt.setQueryTimeout(options.queryTimeout)
+
while (iterator.hasNext) {
val row = iterator.next()
var i = 0
@@ -819,7 +827,8 @@ object JdbcUtils extends Logging {
case _ => df
}
repartitionedDF.rdd.foreachPartition(iterator => savePartition(
- getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
+ getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel,
+ options)
)
}
@@ -841,6 +850,7 @@ object JdbcUtils extends Logging {
val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions"
val statement = conn.createStatement
try {
+ statement.setQueryTimeout(options.queryTimeout)
statement.executeUpdate(sql)
} finally {
statement.close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 77e7edc8e7a20..3b6df45e949e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -31,9 +31,10 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.TaskContext
import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
-import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
@@ -92,25 +93,33 @@ object TextInputJsonDataSource extends JsonDataSource {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType = {
- val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths)
+ val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions)
+
inferFromDataset(json, parsedOptions)
}
def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
- val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0))
- JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String)
+ val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd
+ val rowParser = parsedOptions.encoding.map { enc =>
+ CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow)
+ }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))
+
+ SQLExecution.withSQLConfPropagated(json.sparkSession) {
+ JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+ }
}
private def createBaseDataset(
sparkSession: SparkSession,
- inputPaths: Seq[FileStatus]): Dataset[String] = {
- val paths = inputPaths.map(_.getPath.toString)
+ inputPaths: Seq[FileStatus],
+ parsedOptions: JSONOptions): Dataset[String] = {
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
- paths = paths,
- className = classOf[TextFileFormat].getName
+ paths = inputPaths.map(_.getPath.toString),
+ className = classOf[TextFileFormat].getName,
+ options = parsedOptions.parameters
).resolveRelation(checkFilesExist = false))
.select("value").as(Encoders.STRING)
}
@@ -120,10 +129,14 @@ object TextInputJsonDataSource extends JsonDataSource {
file: PartitionedFile,
parser: JacksonParser,
schema: StructType): Iterator[InternalRow] = {
- val linesReader = new HadoopFileLinesReader(file, conf)
+ val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
+ val textParser = parser.options.encoding
+ .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text))
+ .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text))
+
val safeParser = new FailureSafeParser[Text](
- input => parser.parse(input, CreateJacksonParser.text, textToUTF8String),
+ input => parser.parse(input, textParser, textToUTF8String),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
@@ -144,16 +157,24 @@ object MultiLineJsonDataSource extends JsonDataSource {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType = {
- val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
+ val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions)
val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
- JsonInferSchema.infer(sampled, parsedOptions, createParser)
+ val parser = parsedOptions.encoding
+ .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream))
+ .getOrElse(createParser(_: JsonFactory, _: PortableDataStream))
+
+ SQLExecution.withSQLConfPropagated(sparkSession) {
+ JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+ }
}
private def createBaseRdd(
sparkSession: SparkSession,
- inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
+ inputPaths: Seq[FileStatus],
+ parsedOptions: JSONOptions): RDD[PortableDataStream] = {
val paths = inputPaths.map(_.getPath)
- val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+ val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions(
+ parsedOptions.parameters))
val conf = job.getConfiguration
val name = paths.mkString(",")
FileInputFormat.setInputPaths(job, paths: _*)
@@ -168,11 +189,18 @@ object MultiLineJsonDataSource extends JsonDataSource {
.values
}
- private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
- val path = new Path(record.getPath())
- CreateJacksonParser.inputStream(
- jsonFactory,
- CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path))
+ private def dataToInputStream(dataStream: PortableDataStream): InputStream = {
+ val path = new Path(dataStream.getPath())
+ CodecStreams.createInputStreamWithCloseResource(dataStream.getConfiguration, path)
+ }
+
+ private def createParser(jsonFactory: JsonFactory, stream: PortableDataStream): JsonParser = {
+ CreateJacksonParser.inputStream(jsonFactory, dataToInputStream(stream))
+ }
+
+ private def createParser(enc: String, jsonFactory: JsonFactory,
+ stream: PortableDataStream): JsonParser = {
+ CreateJacksonParser.inputStream(enc, jsonFactory, dataToInputStream(stream))
}
override def readFile(
@@ -187,9 +215,12 @@ object MultiLineJsonDataSource extends JsonDataSource {
UTF8String.fromBytes(ByteStreams.toByteArray(inputStream))
}
}
+ val streamParser = parser.options.encoding
+ .map(enc => CreateJacksonParser.inputStream(enc, _: JsonFactory, _: InputStream))
+ .getOrElse(CreateJacksonParser.inputStream(_: JsonFactory, _: InputStream))
val safeParser = new FailureSafeParser[InputStream](
- input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString),
+ input => parser.parse[InputStream](input, streamParser, partitionedFileString),
parser.options.parseMode,
schema,
parser.options.columnNameOfCorruptRecord)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 0862c746fffad..3b04510d29695 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.json
+import java.nio.charset.{Charset, StandardCharsets}
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -151,7 +153,13 @@ private[json] class JsonOutputWriter(
context: TaskAttemptContext)
extends OutputWriter with Logging {
- private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
+ private val encoding = options.encoding match {
+ case Some(charsetName) => Charset.forName(charsetName)
+ case None => StandardCharsets.UTF_8
+ }
+
+ private val writer = CodecStreams.createOutputStreamWriter(
+ context, new Path(path), encoding)
// create the Generator without separator inserted between 2 records
private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
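A hedged sketch of the new write-side charset handling, assuming an existing DataFrame df: an explicit encoding option selects the charset for the output JSON files, and without it the writer keeps UTF-8, matching the default above. The output path is hypothetical.

    df.write
      .option("encoding", "UTF-16BE")
      .option("lineSep", "\n")
      .json("/tmp/events-utf16be")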
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index a270a6451d5dd..f6edc7bfb3750 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -45,8 +45,9 @@ private[sql] object JsonInferSchema {
val parseMode = configOptions.parseMode
val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
- // perform schema inference on each row and merge afterwards
- val rootType = json.mapPartitions { iter =>
+ // In each RDD partition, perform schema inference on each row and merge afterwards.
+ val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode)
+ val mergedTypesFromPartitions = json.mapPartitions { iter =>
val factory = new JsonFactory()
configOptions.setJacksonOptions(factory)
iter.flatMap { row =>
@@ -66,11 +67,15 @@ private[sql] object JsonInferSchema {
s"Parse Mode: ${FailFastMode.name}.", e)
}
}
- }
- }.fold(StructType(Nil))(
- compatibleRootType(columnNameOfCorruptRecord, parseMode))
+ }.reduceOption(typeMerger).toIterator
+ }
+
+ // Here we fetch the RDD's local iterator and then fold, instead of calling `RDD.fold` directly,
+ // because `RDD.fold` will run the fold function in the DAGScheduler event loop thread, which may
+ // not have an active SparkSession, so `SQLConf.get` may point to the wrong configs.
+ val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger)
- canonicalizeType(rootType) match {
+ canonicalizeType(rootType, configOptions) match {
case Some(st: StructType) => st
case _ =>
// canonicalizeType erases all empty structs, including the only one we want to keep
@@ -176,33 +181,33 @@ private[sql] object JsonInferSchema {
}
/**
- * Convert NullType to StringType and remove StructTypes with no fields
+ * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields,
+ * drops NullTypes or converts them to StringType based on provided options.
*/
- private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match {
- case at @ ArrayType(elementType, _) =>
- for {
- canonicalType <- canonicalizeType(elementType)
- } yield {
- at.copy(canonicalType)
- }
+ private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match {
+ case at: ArrayType =>
+ canonicalizeType(at.elementType, options)
+ .map(t => at.copy(elementType = t))
case StructType(fields) =>
- val canonicalFields: Array[StructField] = for {
- field <- fields
- if field.name.length > 0
- canonicalType <- canonicalizeType(field.dataType)
- } yield {
- field.copy(dataType = canonicalType)
+ val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f =>
+ canonicalizeType(f.dataType, options)
+ .map(t => f.copy(dataType = t))
}
-
- if (canonicalFields.length > 0) {
- Some(StructType(canonicalFields))
+ // SPARK-8093: empty structs should be deleted
+ if (canonicalFields.isEmpty) {
+ None
} else {
- // per SPARK-8093: empty structs should be deleted
+ Some(StructType(canonicalFields))
+ }
+
+ case NullType =>
+ if (options.dropFieldIfAllNull) {
None
+ } else {
+ Some(StringType)
}
- case NullType => Some(StringType)
case other => Some(other)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index dbf3bc6f0ee6c..1de2ca2914c44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -188,6 +188,12 @@ class OrcFileFormat
if (enableVectorizedReader) {
val batchReader = new OrcColumnarBatchReader(
enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity)
+ // SPARK-23399 Register a task completion listener first to call `close()` in all cases.
+ // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM)
+ // after opening a file.
+ val iter = new RecordReaderIterator(batchReader)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
+
batchReader.initialize(fileSplit, taskAttemptContext)
batchReader.initBatch(
reader.getSchema,
@@ -196,8 +202,6 @@ class OrcFileFormat
partitionSchema,
file.partitionValues)
- val iter = new RecordReaderIterator(batchReader)
- Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
iter.asInstanceOf[Iterator[InternalRow]]
} else {
val orcRecordReader = new OrcInputFormat[OrcStruct]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index ba69f9a26c968..60fc9ec7e1f82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -34,6 +34,7 @@ import org.apache.parquet.filter2.compat.FilterCompat
import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS
import org.apache.parquet.hadoop._
+import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel
import org.apache.parquet.hadoop.codec.CodecConfig
import org.apache.parquet.hadoop.util.ContextUtil
import org.apache.parquet.schema.MessageType
@@ -125,16 +126,17 @@ class ParquetFileFormat
conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName)
// SPARK-15719: Disables writing Parquet summary files by default.
- if (conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) {
- conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false)
+ if (conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null
+ && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) {
+ conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE)
}
- if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false)
+ if (ParquetOutputFormat.getJobSummaryLevel(conf) == JobSummaryLevel.NONE
&& !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) {
// output summary is requested, but the class is not a Parquet Committer
logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" +
s" create job summaries. " +
- s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.")
+ s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.")
}
new OutputWriterFactory {
@@ -321,19 +323,6 @@ class ParquetFileFormat
SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
sparkSession.sessionState.conf.isParquetINT96AsTimestamp)
- // Try to push down filters when filter push-down is enabled.
- val pushed =
- if (sparkSession.sessionState.conf.parquetFilterPushDown) {
- filters
- // Collects all converted Parquet filter predicates. Notice that not all predicates can be
- // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
- // is used here.
- .flatMap(ParquetFilters.createFilter(requiredSchema, _))
- .reduceOption(FilterApi.and)
- } else {
- None
- }
-
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
@@ -351,12 +340,27 @@ class ParquetFileFormat
val timestampConversion: Boolean =
sparkSession.sessionState.conf.isParquetINT96TimestampConversion
val capacity = sqlConf.parquetVectorizedReaderBatchSize
+ val enableParquetFilterPushDown: Boolean =
+ sparkSession.sessionState.conf.parquetFilterPushDown
// Whole stage codegen (PhysicalRDD) is able to deal with batches directly
val returningBatch = supportBatch(sparkSession, resultSchema)
+ val pushDownDate = sqlConf.parquetFilterPushDownDate
(file: PartitionedFile) => {
assert(file.partitionValues.numFields == partitionSchema.size)
+ // Try to push down filters when filter push-down is enabled.
+ val pushed = if (enableParquetFilterPushDown) {
+ filters
+ // Collects all converted Parquet filter predicates. Notice that not all predicates can be
+ // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
+ // is used here.
+ .flatMap(new ParquetFilters(pushDownDate).createFilter(requiredSchema, _))
+ .reduceOption(FilterApi.and)
+ } else {
+ None
+ }
+
val fileSplit =
new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty)
@@ -395,16 +399,21 @@ class ParquetFileFormat
ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get)
}
val taskContext = Option(TaskContext.get())
- val parquetReader = if (enableVectorizedReader) {
+ if (enableVectorizedReader) {
val vectorizedReader = new VectorizedParquetRecordReader(
convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity)
+ val iter = new RecordReaderIterator(vectorizedReader)
+ // SPARK-23457 Register a task completion listener before `initialization`.
+ taskContext.foreach(_.addTaskCompletionListener(_ => iter.close()))
vectorizedReader.initialize(split, hadoopAttemptContext)
logDebug(s"Appending $partitionSchema ${file.partitionValues}")
vectorizedReader.initBatch(partitionSchema, file.partitionValues)
if (returningBatch) {
vectorizedReader.enableReturningBatches()
}
- vectorizedReader
+
+ // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
+ iter.asInstanceOf[Iterator[InternalRow]]
} else {
logDebug(s"Falling back to parquet-mr")
// ParquetRecordReader returns UnsafeRow
@@ -414,18 +423,11 @@ class ParquetFileFormat
} else {
new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz))
}
+ val iter = new RecordReaderIterator(reader)
+ // SPARK-23457 Register a task completion listener before `initialization`.
+ taskContext.foreach(_.addTaskCompletionListener(_ => iter.close()))
reader.initialize(split, hadoopAttemptContext)
- reader
- }
- val iter = new RecordReaderIterator(parquetReader)
- taskContext.foreach(_.addTaskCompletionListener(_ => iter.close()))
-
- // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
- if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] &&
- enableVectorizedReader) {
- iter.asInstanceOf[Iterator[InternalRow]]
- } else {
val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
val joinedRow = new JoinedRow()
val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
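The reordering above (SPARK-23457) matters because the completion listener must be registered before initialize(): if initialize() throws, the already-registered callback still closes the reader instead of leaking it. A minimal, self-contained sketch of the same ordering pattern, using a hypothetical Cleanup registry in place of Spark's TaskContext:

object ListenerOrderingSketch {
  // Stand-in for TaskContext task-completion listeners (hypothetical, illustration only).
  final class Cleanup {
    private var callbacks: List[() => Unit] = Nil
    def register(cb: () => Unit): Unit = callbacks ::= cb
    def runAll(): Unit = callbacks.foreach(cb => cb())
  }

  final class Reader {
    def initialize(): Unit = throw new RuntimeException("corrupt footer")  // may fail
    def close(): Unit = println("reader closed")
  }

  def main(args: Array[String]): Unit = {
    val cleanup = new Cleanup
    val reader = new Reader
    cleanup.register(() => reader.close())   // register close() BEFORE initialize()
    try reader.initialize()
    catch { case e: RuntimeException => println(s"init failed: ${e.getMessage}") }
    finally cleanup.runAll()                 // the reader is still closed on failure
  }
}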
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 763841efbd9f3..310626197a763 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -17,17 +17,25 @@
package org.apache.spark.sql.execution.datasources.parquet
+import java.sql.Date
+
import org.apache.parquet.filter2.predicate._
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.io.api.Binary
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate
import org.apache.spark.sql.sources
import org.apache.spark.sql.types._
/**
* Some utility function to convert Spark data source filters to Parquet filters.
*/
-private[parquet] object ParquetFilters {
+private[parquet] class ParquetFilters(pushDownDate: Boolean) {
+
+ private def dateToDays(date: Date): SQLDate = {
+ DateTimeUtils.fromJavaDate(date)
+ }
private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case BooleanType =>
@@ -50,6 +58,10 @@ private[parquet] object ParquetFilters {
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull)
+ case DateType if pushDownDate =>
+ (n: String, v: Any) => FilterApi.eq(
+ intColumn(n),
+ Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
}
private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -72,6 +84,10 @@ private[parquet] object ParquetFilters {
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull)
+ case DateType if pushDownDate =>
+ (n: String, v: Any) => FilterApi.notEq(
+ intColumn(n),
+ Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
}
private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -91,6 +107,10 @@ private[parquet] object ParquetFilters {
case BinaryType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case DateType if pushDownDate =>
+ (n: String, v: Any) => FilterApi.lt(
+ intColumn(n),
+ Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
}
private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -110,6 +130,10 @@ private[parquet] object ParquetFilters {
case BinaryType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case DateType if pushDownDate =>
+ (n: String, v: Any) => FilterApi.ltEq(
+ intColumn(n),
+ Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
}
private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -129,6 +153,10 @@ private[parquet] object ParquetFilters {
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case DateType if pushDownDate =>
+ (n: String, v: Any) => FilterApi.gt(
+ intColumn(n),
+ Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
}
private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
@@ -148,6 +176,10 @@ private[parquet] object ParquetFilters {
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case DateType if pushDownDate =>
+ (n: String, v: Any) => FilterApi.gtEq(
+ intColumn(n),
+ Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
}
/**
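The DateType cases added above push dates down as Parquet INT32 values holding days since the Unix epoch (Spark's internal SQLDate). A rough, self-contained illustration of that encoding using java.time instead of Spark's DateTimeUtils (which additionally accounts for the JVM time zone behind java.sql.Date):

import java.sql.Date
import java.time.LocalDate
import java.time.temporal.ChronoUnit

object DateToDaysSketch {
  // Approximation of the pushed-down predicate value: days since 1970-01-01.
  def dateToDays(date: Date): Int =
    ChronoUnit.DAYS.between(LocalDate.ofEpochDay(0), date.toLocalDate).toInt

  def main(args: Array[String]): Unit = {
    println(dateToDays(Date.valueOf("1970-01-02")))  // 1
    println(dateToDays(Date.valueOf("2018-01-01")))  // 17532
  }
}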
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
index f36a89a4c3c5f..9cfc30725f03a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
@@ -81,7 +81,10 @@ object ParquetOptions {
"uncompressed" -> CompressionCodecName.UNCOMPRESSED,
"snappy" -> CompressionCodecName.SNAPPY,
"gzip" -> CompressionCodecName.GZIP,
- "lzo" -> CompressionCodecName.LZO)
+ "lzo" -> CompressionCodecName.LZO,
+ "lz4" -> CompressionCodecName.LZ4,
+ "brotli" -> CompressionCodecName.BROTLI,
+ "zstd" -> CompressionCodecName.ZSTD)
def getParquetCompressionCodecName(name: String): String = {
shortParquetCompressionCodecNames(name).name()
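With the three new short names registered above, the codecs can be requested through the regular writer option or the session config. A small sketch (paths and data are made up; zstd, lz4 and brotli still require the corresponding codec implementation on the classpath):

import org.apache.spark.sql.SparkSession

object ZstdParquetExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("zstd-parquet").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1), ("b", 2)).toDF("k", "v")
    // Per-write option, now accepted alongside snappy/gzip/lzo/uncompressed.
    df.write.mode("overwrite").option("compression", "zstd").parquet("/tmp/zstd_parquet_demo")
    // Or a session-wide default for all Parquet writes.
    spark.conf.set("spark.sql.parquet.compression.codec", "lz4")

    spark.stop()
  }
}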
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 5dbcf4a915cbf..cab00251622b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DDLUtils
@@ -61,7 +61,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
case _: ClassNotFoundException => u
case e: Exception =>
// the provider is valid, but failed to create a logical plan
- u.failAnalysis(e.getMessage)
+ u.failAnalysis(e.getMessage, e)
}
}
}
@@ -118,6 +118,14 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " +
s"`${specifiedProvider.getSimpleName}`.")
}
+ tableDesc.storage.locationUri match {
+ case Some(location) if location.getPath != existingTable.location.getPath =>
+ throw new AnalysisException(
+ s"The location of the existing table ${tableIdentWithDB.quotedString} is " +
+ s"`${existingTable.location}`. It doesn't match the specified location " +
+ s"`${tableDesc.location}`.")
+ case _ =>
+ }
if (query.schema.length != existingTable.schema.length) {
throw new AnalysisException(
@@ -178,7 +186,8 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
c.copy(
tableDesc = existingTable,
- query = Some(newQuery))
+ query = Some(DDLPreprocessingUtils.castAndRenameQueryOutput(
+ newQuery, existingTable.schema.toAttributes, conf)))
// Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity
// config, and do various checks:
@@ -316,7 +325,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
* table. It also does data type casting and field renaming, to make sure that the columns to be
* inserted have the correct data type and fields have the correct names.
*/
-case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
+case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
private def preprocess(
insert: InsertIntoTable,
tblName: String,
@@ -336,6 +345,8 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit
s"including ${staticPartCols.size} partition column(s) having constant value(s).")
}
+ val newQuery = DDLPreprocessingUtils.castAndRenameQueryOutput(
+ insert.query, expectedColumns, conf)
if (normalizedPartSpec.nonEmpty) {
if (normalizedPartSpec.size != partColNames.length) {
throw new AnalysisException(
@@ -346,37 +357,11 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit
""".stripMargin)
}
- castAndRenameChildOutput(insert.copy(partition = normalizedPartSpec), expectedColumns)
+ insert.copy(query = newQuery, partition = normalizedPartSpec)
} else {
// All partition columns are dynamic because the InsertIntoTable command does
// not explicitly specify partitioning columns.
- castAndRenameChildOutput(insert, expectedColumns)
- .copy(partition = partColNames.map(_ -> None).toMap)
- }
- }
-
- private def castAndRenameChildOutput(
- insert: InsertIntoTable,
- expectedOutput: Seq[Attribute]): InsertIntoTable = {
- val newChildOutput = expectedOutput.zip(insert.query.output).map {
- case (expected, actual) =>
- if (expected.dataType.sameType(actual.dataType) &&
- expected.name == actual.name &&
- expected.metadata == actual.metadata) {
- actual
- } else {
- // Renaming is needed for handling the following cases like
- // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2
- // 2) Target tables have column metadata
- Alias(cast(actual, expected.dataType), expected.name)(
- explicitMetadata = Option(expected.metadata))
- }
- }
-
- if (newChildOutput == insert.query.output) {
- insert
- } else {
- insert.copy(query = Project(newChildOutput, insert.query))
+ insert.copy(query = newQuery, partition = partColNames.map(_ -> None).toMap)
}
}
@@ -491,3 +476,36 @@ object PreWriteCheck extends (LogicalPlan => Unit) {
}
}
}
+
+object DDLPreprocessingUtils {
+
+ /**
+ * Adjusts the name and data type of the input query output columns, to match the expectation.
+ */
+ def castAndRenameQueryOutput(
+ query: LogicalPlan,
+ expectedOutput: Seq[Attribute],
+ conf: SQLConf): LogicalPlan = {
+ val newChildOutput = expectedOutput.zip(query.output).map {
+ case (expected, actual) =>
+ if (expected.dataType.sameType(actual.dataType) &&
+ expected.name == actual.name &&
+ expected.metadata == actual.metadata) {
+ actual
+ } else {
+ // Renaming is needed for handling the following cases like
+ // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2
+ // 2) Target tables have column metadata
+ Alias(
+ Cast(actual, expected.dataType, Option(conf.sessionLocalTimeZone)),
+ expected.name)(explicitMetadata = Option(expected.metadata))
+ }
+ }
+
+ if (newChildOutput == query.output) {
+ query
+ } else {
+ Project(newChildOutput, query)
+ }
+ }
+}
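The visible effect of castAndRenameQueryOutput is easiest to see on a positional INSERT whose query output does not match the target schema. A rough sketch against public APIs (table name and values are made up; in this era Spark silently casts rather than failing):

import org.apache.spark.sql.SparkSession

object InsertCastSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("insert-cast").getOrCreate()

    spark.sql("CREATE TABLE target (a INT, b STRING) USING parquet")
    // The query output is (bigint, int); PreprocessTableInsertion wraps both columns in
    // Cast + Alias via DDLPreprocessingUtils.castAndRenameQueryOutput so the insert lines up.
    spark.sql("INSERT INTO target SELECT 1L AS x, 2 AS y")
    spark.sql("SELECT * FROM target").show()   // a = 1 (int), b = "2" (string)

    spark.stop()
  }
}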
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index c661e9bd3b94c..e93908da43535 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -17,11 +17,8 @@
package org.apache.spark.sql.execution.datasources.text
-import java.io.Closeable
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.TaskContext
@@ -29,7 +26,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
@@ -89,7 +86,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
- new TextOutputWriter(path, dataSchema, context)
+ new TextOutputWriter(path, dataSchema, textOptions.lineSeparatorInWrite, context)
}
override def getFileExtension(context: TaskAttemptContext): String = {
@@ -113,18 +110,18 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
- readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions.wholeText)
+ readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions)
}
private def readToUnsafeMem(
conf: Broadcast[SerializableConfiguration],
requiredSchema: StructType,
- wholeTextMode: Boolean): (PartitionedFile) => Iterator[UnsafeRow] = {
+ textOptions: TextOptions): (PartitionedFile) => Iterator[UnsafeRow] = {
(file: PartitionedFile) => {
val confValue = conf.value.value
- val reader = if (!wholeTextMode) {
- new HadoopFileLinesReader(file, confValue)
+ val reader = if (!textOptions.wholeText) {
+ new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue)
} else {
new HadoopFileWholeTextReader(file, confValue)
}
@@ -133,16 +130,13 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
val emptyUnsafeRow = new UnsafeRow(0)
reader.map(_ => emptyUnsafeRow)
} else {
- val unsafeRow = new UnsafeRow(1)
- val bufferHolder = new BufferHolder(unsafeRow)
- val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1)
+ val unsafeRowWriter = new UnsafeRowWriter(1)
reader.map { line =>
// Writes to an UnsafeRow directly
- bufferHolder.reset()
+ unsafeRowWriter.reset()
unsafeRowWriter.write(0, line.getBytes, 0, line.getLength)
- unsafeRow.setTotalSize(bufferHolder.totalSize())
- unsafeRow
+ unsafeRowWriter.getRow()
}
}
}
@@ -152,6 +146,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
class TextOutputWriter(
path: String,
dataSchema: StructType,
+ lineSeparator: Array[Byte],
context: TaskAttemptContext)
extends OutputWriter {
@@ -162,7 +157,7 @@ class TextOutputWriter(
val utf8string = row.getUTF8String(0)
utf8string.writeTo(writer)
}
- writer.write('\n')
+ writer.write(lineSeparator)
}
override def close(): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
index 2a661561ab51e..e4e201995faa2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.text
+import java.nio.charset.{Charset, StandardCharsets}
+
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs}
/**
@@ -39,9 +41,25 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti
*/
val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean
+ val encoding: Option[String] = parameters.get(ENCODING)
+
+ val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep =>
+ require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.")
+
+ lineSep
+ }
+
+ // Note that the option 'lineSep' uses a different default value in read and write.
+ val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
+ lineSep.getBytes(encoding.map(Charset.forName(_)).getOrElse(StandardCharsets.UTF_8))
+ }
+ val lineSeparatorInWrite: Array[Byte] =
+ lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8))
}
-private[text] object TextOptions {
+private[datasources] object TextOptions {
val COMPRESSION = "compression"
val WHOLETEXT = "wholetext"
+ val ENCODING = "encoding"
+ val LINE_SEPARATOR = "lineSep"
}
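Together, the new ENCODING and LINE_SEPARATOR keys surface as read/write options on the text source. A short sketch of round-tripping a custom record separator (path and separator are made up):

import org.apache.spark.sql.SparkSession

object TextLineSepSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("text-linesep").getOrCreate()
    import spark.implicits._

    val path = "/tmp/text_linesep_demo"
    // Write records separated by "||" instead of the default '\n'.
    Seq("a", "b", "c").toDF("value").write.mode("overwrite").option("lineSep", "||").text(path)
    // Read them back with the same separator; without the option, the reader splits on newlines.
    spark.read.option("lineSep", "||").text(path).show()

    spark.stop()
  }
}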
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 5ed0ba71e94c7..8d6fb3820d420 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -22,24 +22,25 @@ import scala.reflect.ClassTag
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.reader.InputPartition
-class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T])
+class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T])
extends Partition with Serializable
class DataSourceRDD[T: ClassTag](
sc: SparkContext,
- @transient private val readerFactories: java.util.List[DataReaderFactory[T]])
+ @transient private val inputPartitions: Seq[InputPartition[T]])
extends RDD[T](sc, Nil) {
override protected def getPartitions: Array[Partition] = {
- readerFactories.asScala.zipWithIndex.map {
- case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
+ inputPartitions.zipWithIndex.map {
+ case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition)
}.toArray
}
override def compute(split: Partition, context: TaskContext): Iterator[T] = {
- val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader()
+ val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition
+ .createPartitionReader()
context.addTaskCompletionListener(_ => reader.close())
val iter = new Iterator[T] {
private[this] var valuePrepared = false
@@ -63,6 +64,6 @@ class DataSourceRDD[T: ClassTag](
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations()
+ split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations()
}
}
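For reference, the contract DataSourceRDD now relies on: each InputPartition describes one RDD partition and creates a fresh InputPartitionReader on the executor. A toy implementation sketch (the class and its integer range are made up) against the v2 reader interfaces used above:

import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}

// Produces the integers [start, end); preferredLocations() keeps its default (no locality).
class RangeInputPartition(start: Int, end: Int) extends InputPartition[Int] {
  override def createPartitionReader(): InputPartitionReader[Int] =
    new InputPartitionReader[Int] {
      private var current = start - 1
      override def next(): Boolean = { current += 1; current < end }
      override def get(): Int = current
      override def close(): Unit = ()
    }
}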
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala
deleted file mode 100644
index 81219e9771bd8..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.v2
-
-import java.util.Objects
-
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.sources.v2.reader._
-
-/**
- * A base class for data source reader holder with customized equals/hashCode methods.
- */
-trait DataSourceReaderHolder {
-
- /**
- * The output of the data source reader, w.r.t. column pruning.
- */
- def output: Seq[Attribute]
-
- /**
- * The held data source reader.
- */
- def reader: DataSourceReader
-
- /**
- * The metadata of this data source reader that can be used for equality test.
- */
- private def metadata: Seq[Any] = {
- val filters: Any = reader match {
- case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet
- case s: SupportsPushDownFilters => s.pushedFilters().toSet
- case _ => Nil
- }
- Seq(output, reader.getClass, filters)
- }
-
- def canEqual(other: Any): Boolean
-
- override def equals(other: Any): Boolean = other match {
- case other: DataSourceReaderHolder =>
- canEqual(other) && metadata.length == other.metadata.length &&
- metadata.zip(other.metadata).forall { case (l, r) => l == r }
- case _ => false
- }
-
- override def hashCode(): Int = {
- metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 38f6b15224788..7613eb210c659 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -17,19 +17,41 @@
package org.apache.spark.sql.execution.datasources.v2
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
-import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema}
+import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics}
+import org.apache.spark.sql.types.StructType
+/**
+ * A logical plan representing a data source v2 scan.
+ *
+ * @param source An instance of a [[DataSourceV2]] implementation.
+ * @param options The options for this scan. Used to create fresh [[DataSourceReader]].
+ * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh
+ * [[DataSourceReader]].
+ */
case class DataSourceV2Relation(
+ source: DataSourceV2,
output: Seq[AttributeReference],
- reader: DataSourceReader)
- extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder {
+ options: Map[String, String],
+ userSpecifiedSchema: Option[StructType])
+ extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat {
- override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation]
+ import DataSourceV2Relation._
- override def computeStats(): Statistics = reader match {
+ override def pushedFilters: Seq[Expression] = Seq.empty
+
+ override def simpleString: String = "RelationV2 " + metadataString
+
+ def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema)
+
+ override def computeStats(): Statistics = newReader match {
case r: SupportsReportStatistics =>
Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
case _ =>
@@ -42,17 +64,101 @@ case class DataSourceV2Relation(
}
/**
- * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical
- * to the non-streaming relation.
+ * A specialization of [[DataSourceV2Relation]] with the streaming bit set to true.
+ *
+ * Note that this plan has a mutable reader, so Spark won't apply operator push-down for this plan,
+ * to avoid making the plan mutable. We should consolidate this plan and [[DataSourceV2Relation]]
+ * after we figure out how to apply operator push-down for streaming data sources.
*/
-class StreamingDataSourceV2Relation(
+case class StreamingDataSourceV2Relation(
output: Seq[AttributeReference],
- reader: DataSourceReader) extends DataSourceV2Relation(output, reader) {
+ source: DataSourceV2,
+ options: Map[String, String],
+ reader: DataSourceReader)
+ extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat {
+
override def isStreaming: Boolean = true
+
+ override def simpleString: String = "Streaming RelationV2 " + metadataString
+
+ override def pushedFilters: Seq[Expression] = Nil
+
+ override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))
+
+ // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
+ override def equals(other: Any): Boolean = other match {
+ case other: StreamingDataSourceV2Relation =>
+ output == other.output && reader.getClass == other.reader.getClass && options == other.options
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ Seq(output, source, options).hashCode()
+ }
+
+ override def computeStats(): Statistics = reader match {
+ case r: SupportsReportStatistics =>
+ Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
+ case _ =>
+ Statistics(sizeInBytes = conf.defaultSizeInBytes)
+ }
}
object DataSourceV2Relation {
- def apply(reader: DataSourceReader): DataSourceV2Relation = {
- new DataSourceV2Relation(reader.readSchema().toAttributes, reader)
+ private implicit class SourceHelpers(source: DataSourceV2) {
+ def asReadSupport: ReadSupport = {
+ source match {
+ case support: ReadSupport =>
+ support
+ case _: ReadSupportWithSchema =>
+ // This method is only called if there is no user-supplied schema. If there is no
+ // user-supplied schema and ReadSupport was not implemented, throw a helpful exception.
+ throw new AnalysisException(s"Data source requires a user-supplied schema: $name")
+ case _ =>
+ throw new AnalysisException(s"Data source is not readable: $name")
+ }
+ }
+
+ def asReadSupportWithSchema: ReadSupportWithSchema = {
+ source match {
+ case support: ReadSupportWithSchema =>
+ support
+ case _: ReadSupport =>
+ throw new AnalysisException(
+ s"Data source does not support user-supplied schema: $name")
+ case _ =>
+ throw new AnalysisException(s"Data source is not readable: $name")
+ }
+ }
+
+ def name: String = {
+ source match {
+ case registered: DataSourceRegister =>
+ registered.shortName()
+ case _ =>
+ source.getClass.getSimpleName
+ }
+ }
+
+ def createReader(
+ options: Map[String, String],
+ userSpecifiedSchema: Option[StructType]): DataSourceReader = {
+ val v2Options = new DataSourceOptions(options.asJava)
+ userSpecifiedSchema match {
+ case Some(s) =>
+ asReadSupportWithSchema.createReader(s, v2Options)
+ case _ =>
+ asReadSupport.createReader(v2Options)
+ }
+ }
+ }
+
+ def create(
+ source: DataSourceV2,
+ options: Map[String, String],
+ userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = {
+ val reader = source.createReader(options, userSpecifiedSchema)
+ DataSourceV2Relation(
+ source, reader.readSchema().toAttributes, options, userSpecifiedSchema)
}
}
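From the read path, which SourceHelpers branch runs is decided purely by whether the user supplied a schema. A hedged sketch of the two entry points (the format class name is hypothetical and will not resolve as-is; `spark` is an existing SparkSession):

// No user schema: the source must implement ReadSupport, otherwise the relation
// reports "Data source requires a user-supplied schema" / "Data source is not readable".
val withInferredSchema = spark.read
  .format("com.example.CustomSourceV2")
  .load()

// Explicit schema: the source must implement ReadSupportWithSchema, otherwise
// "Data source does not support user-supplied schema" is raised.
val withUserSchema = spark.read
  .format("com.example.CustomSourceV2")
  .schema("id INT, name STRING")
  .load()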
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 7d9581be4db89..c6a7684bf6ab0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -25,23 +25,49 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
/**
* Physical plan node for scanning data from a data source.
*/
case class DataSourceV2ScanExec(
output: Seq[AttributeReference],
+ @transient source: DataSourceV2,
+ @transient options: Map[String, String],
+ @transient pushedFilters: Seq[Expression],
@transient reader: DataSourceReader)
- extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {
+ extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan {
- override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]
+ override def simpleString: String = "ScanV2 " + metadataString
+
+ // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
+ override def equals(other: Any): Boolean = other match {
+ case other: DataSourceV2ScanExec =>
+ output == other.output && reader.getClass == other.reader.getClass && options == other.options
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ Seq(output, source, options).hashCode()
+ }
override def outputPartitioning: physical.Partitioning = reader match {
+ case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchPartitions.size == 1 =>
+ SinglePartition
+
+ case r: SupportsScanColumnarBatch if !r.enableBatchRead() && partitions.size == 1 =>
+ SinglePartition
+
+ case r if !r.isInstanceOf[SupportsScanColumnarBatch] && partitions.size == 1 =>
+ SinglePartition
+
case s: SupportsReportPartitioning =>
new DataSourcePartitioning(
s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))
@@ -49,31 +75,38 @@ case class DataSourceV2ScanExec(
case _ => super.outputPartitioning
}
- private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match {
- case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories()
+ private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match {
+ case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala
case _ =>
- reader.createDataReaderFactories().asScala.map {
- new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
- }.asJava
+ reader.planInputPartitions().asScala.map {
+ new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow]
+ }
}
- private lazy val inputRDD: RDD[InternalRow] = reader match {
+ private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match {
case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
assert(!reader.isInstanceOf[ContinuousReader],
"continuous stream reader does not support columnar read yet.")
- new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories())
- .asInstanceOf[RDD[InternalRow]]
+ r.planBatchInputPartitions().asScala
+ }
+ private lazy val inputRDD: RDD[InternalRow] = reader match {
case _: ContinuousReader =>
EpochCoordinatorRef.get(
sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
sparkContext.env)
- .askSync[Unit](SetReaderPartitions(readerFactories.size()))
- new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
- .asInstanceOf[RDD[InternalRow]]
+ .askSync[Unit](SetReaderPartitions(partitions.size))
+ new ContinuousDataSourceRDD(
+ sparkContext,
+ sqlContext.conf.continuousStreamingExecutorQueueSize,
+ sqlContext.conf.continuousStreamingExecutorPollIntervalMs,
+ partitions).asInstanceOf[RDD[InternalRow]]
+
+ case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
+ new DataSourceRDD(sparkContext, batchPartitions).asInstanceOf[RDD[InternalRow]]
case _ =>
- new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]]
+ new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]]
}
override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD)
@@ -98,19 +131,22 @@ case class DataSourceV2ScanExec(
}
}
-class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType)
- extends DataReaderFactory[UnsafeRow] {
+class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType)
+ extends InputPartition[UnsafeRow] {
- override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations
+ override def preferredLocations: Array[String] = partition.preferredLocations
- override def createDataReader: DataReader[UnsafeRow] = {
- new RowToUnsafeDataReader(
- rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind())
+ override def createPartitionReader: InputPartitionReader[UnsafeRow] = {
+ new RowToUnsafeInputPartitionReader(
+ partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind())
}
}
-class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row])
- extends DataReader[UnsafeRow] {
+class RowToUnsafeInputPartitionReader(
+ val rowReader: InputPartitionReader[Row],
+ encoder: ExpressionEncoder[Row])
+
+ extends InputPartitionReader[UnsafeRow] {
override def next: Boolean = rowReader.next
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index df5b524485f54..182aa2906cf1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -17,18 +17,130 @@
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.sql.Strategy
+import scala.collection.mutable
+
+import org.apache.spark.sql.{sources, Strategy}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
+import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
object DataSourceV2Strategy extends Strategy {
+
+ /**
+ * Pushes down filters to the data source reader.
+ *
+ * @return pushed filters and post-scan filters.
+ */
+ private def pushFilters(
+ reader: DataSourceReader,
+ filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+ reader match {
+ case r: SupportsPushDownCatalystFilters =>
+ val postScanFilters = r.pushCatalystFilters(filters.toArray)
+ val pushedFilters = r.pushedCatalystFilters()
+ (pushedFilters, postScanFilters)
+
+ case r: SupportsPushDownFilters =>
+ // A map from translated data source filters to original catalyst filter expressions.
+ val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression]
+ // Catalyst filter expression that can't be translated to data source filters.
+ val untranslatableExprs = mutable.ArrayBuffer.empty[Expression]
+
+ for (filterExpr <- filters) {
+ val translated = DataSourceStrategy.translateFilter(filterExpr)
+ if (translated.isDefined) {
+ translatedFilterToExpr(translated.get) = filterExpr
+ } else {
+ untranslatableExprs += filterExpr
+ }
+ }
+
+ // Data source filters that need to be evaluated again after scanning, which means
+ // the data source cannot guarantee the rows returned can pass these filters.
+ // As a result we must return them so Spark can plan an extra filter operator.
+ val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray)
+ .map(translatedFilterToExpr)
+ // The filters which are marked as pushed to this data source
+ val pushedFilters = r.pushedFilters().map(translatedFilterToExpr)
+ (pushedFilters, untranslatableExprs ++ postScanFilters)
+
+ case _ => (Nil, filters)
+ }
+ }
+
+ /**
+ * Applies column pruning to the data source, w.r.t. the references of the given expressions.
+ *
+ * @return new output attributes after column pruning.
+ */
+ // TODO: nested column pruning.
+ private def pruneColumns(
+ reader: DataSourceReader,
+ relation: DataSourceV2Relation,
+ exprs: Seq[Expression]): Seq[AttributeReference] = {
+ reader match {
+ case r: SupportsPushDownRequiredColumns =>
+ val requiredColumns = AttributeSet(exprs.flatMap(_.references))
+ val neededOutput = relation.output.filter(requiredColumns.contains)
+ if (neededOutput != relation.output) {
+ r.pruneColumns(neededOutput.toStructType)
+ val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
+ r.readSchema().toAttributes.map {
+ // We have to keep the attribute id during transformation.
+ a => a.withExprId(nameToAttr(a.name).exprId)
+ }
+ } else {
+ relation.output
+ }
+
+ case _ => relation.output
+ }
+ }
+
+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case DataSourceV2Relation(output, reader) =>
- DataSourceV2ScanExec(output, reader) :: Nil
+ case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
+ val reader = relation.newReader()
+ // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
+ // `postScanFilters` need to be evaluated after the scan.
+ // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
+ val (pushedFilters, postScanFilters) = pushFilters(reader, filters)
+ val output = pruneColumns(reader, relation, project ++ postScanFilters)
+ logInfo(
+ s"""
+ |Pushing operators to ${relation.source.getClass}
+ |Pushed Filters: ${pushedFilters.mkString(", ")}
+ |Post-Scan Filters: ${postScanFilters.mkString(",")}
+ |Output: ${output.mkString(", ")}
+ """.stripMargin)
+
+ val scan = DataSourceV2ScanExec(
+ output, relation.source, relation.options, pushedFilters, reader)
+
+ val filterCondition = postScanFilters.reduceLeftOption(And)
+ val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
+
+ val withProjection = if (withFilter.output != project) {
+ ProjectExec(project, withFilter)
+ } else {
+ withFilter
+ }
+
+ withProjection :: Nil
+
+ case r: StreamingDataSourceV2Relation =>
+ DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil
case WriteToDataSourceV2(writer, query) =>
WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
+ case WriteToContinuousDataSource(writer, query) =>
+ WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil
+
case _ => Nil
}
}
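The split returned by pushFilters can be pictured with a toy reader that only understands equality predicates; everything it rejects has to stay in the plan as a FilterExec. A self-contained sketch with toy filter types (not Spark's org.apache.spark.sql.sources.Filter):

object PushFiltersSketch {
  sealed trait ToyFilter
  case class EqualTo(attr: String, value: Any) extends ToyFilter
  case class GreaterThan(attr: String, value: Any) extends ToyFilter

  // Analogue of SupportsPushDownFilters: remembers what it accepted, returns what it rejected.
  class ToyReader {
    private var accepted: Seq[ToyFilter] = Nil
    def pushFilters(filters: Seq[ToyFilter]): Seq[ToyFilter] = {
      val (ok, rejected) = filters.partition(_.isInstanceOf[EqualTo])
      accepted = ok
      rejected
    }
    def pushedFilters: Seq[ToyFilter] = accepted
  }

  def main(args: Array[String]): Unit = {
    val reader = new ToyReader
    val postScan = reader.pushFilters(Seq(EqualTo("id", 1), GreaterThan("age", 30)))
    println(s"pushed:    ${reader.pushedFilters}")  // List(EqualTo(id,1))       -> evaluated by the source
    println(s"post-scan: $postScan")                // List(GreaterThan(age,30)) -> FilterExec above the scan
  }
}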
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
new file mode 100644
index 0000000000000..97e6c6d702acb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.commons.lang3.StringUtils
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.DataSourceV2
+import org.apache.spark.util.Utils
+
+/**
+ * A trait that can be used by data source v2 related query plans (both logical and physical) to
+ * provide a string format of the data source information for explain.
+ */
+trait DataSourceV2StringFormat {
+
+ /**
+ * The instance of this data source implementation. Note that we only consider its class in
+ * equals/hashCode, not the instance itself.
+ */
+ def source: DataSourceV2
+
+ /**
+ * The output of the data source reader, w.r.t. column pruning.
+ */
+ def output: Seq[Attribute]
+
+ /**
+ * The options for this data source reader.
+ */
+ def options: Map[String, String]
+
+ /**
+ * The filters which have been pushed to the data source.
+ */
+ def pushedFilters: Seq[Expression]
+
+ private def sourceName: String = source match {
+ case registered: DataSourceRegister => registered.shortName()
+ // source.getClass.getSimpleName can cause Malformed class name error,
+ // call safer `Utils.getSimpleName` instead
+ case _ => Utils.getSimpleName(source.getClass)
+ }
+
+ def metadataString: String = {
+ val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)]
+
+ if (pushedFilters.nonEmpty) {
+ entries += "Filters" -> pushedFilters.mkString("[", ", ", "]")
+ }
+
+ // TODO: we should only display some standard options like path, table, etc.
+ if (options.nonEmpty) {
+ entries += "Options" -> Utils.redact(options).map {
+ case (k, v) => s"$k=$v"
+ }.mkString("[", ",", "]")
+ }
+
+ val outputStr = Utils.truncatedString(output, "[", ", ", "]")
+
+ val entriesStr = if (entries.nonEmpty) {
+ Utils.truncatedString(entries.map {
+ case (key, value) => key + ": " + StringUtils.abbreviate(value, 100)
+ }, " (", ", ", ")")
+ } else {
+ ""
+ }
+
+ s"$sourceName$outputStr$entriesStr"
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
deleted file mode 100644
index 1ca6cbf061b4e..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.v2
-
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper}
-import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-import org.apache.spark.sql.sources
-import org.apache.spark.sql.sources.v2.reader._
-
-/**
- * Pushes down various operators to the underlying data source for better performance. Operators are
- * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you
- * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the
- * data source should execute FILTER before LIMIT. And required columns are calculated at the end,
- * because when more operators are pushed down, we may need less columns at Spark side.
- */
-object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper {
- override def apply(plan: LogicalPlan): LogicalPlan = {
- // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may
- // appear in many places for column pruning.
- // TODO: Ideally column pruning should be implemented via a plan property that is propagated
- // top-down, then we can simplify the logic here and only collect target operators.
- val filterPushed = plan transformUp {
- case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) =>
- val (candidates, nonDeterministic) =
- splitConjunctivePredicates(condition).partition(_.deterministic)
-
- val stayUpFilters: Seq[Expression] = reader match {
- case r: SupportsPushDownCatalystFilters =>
- r.pushCatalystFilters(candidates.toArray)
-
- case r: SupportsPushDownFilters =>
- // A map from original Catalyst expressions to corresponding translated data source
- // filters. If a predicate is not in this map, it means it cannot be pushed down.
- val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p =>
- DataSourceStrategy.translateFilter(p).map(f => p -> f)
- }.toMap
-
- // Catalyst predicate expressions that cannot be converted to data source filters.
- val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains)
-
- // Data source filters that cannot be pushed down. An unhandled filter means
- // the data source cannot guarantee the rows returned can pass the filter.
- // As a result we must return it so Spark can plan an extra filter operator.
- val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet
- val unhandledPredicates = translatedMap.filter { case (_, f) =>
- unhandledFilters.contains(f)
- }.keys
-
- nonConvertiblePredicates ++ unhandledPredicates
-
- case _ => candidates
- }
-
- val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And)
- val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r)
- if (withFilter.output == fields) {
- withFilter
- } else {
- Project(fields, withFilter)
- }
- }
-
- // TODO: add more push down rules.
-
- val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
- // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
- RemoveRedundantProject(columnPruned)
- }
-
- // TODO: nested fields pruning
- private def pushDownRequiredColumns(
- plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = {
- plan match {
- case p @ Project(projectList, child) =>
- val required = projectList.flatMap(_.references)
- p.copy(child = pushDownRequiredColumns(child, AttributeSet(required)))
-
- case f @ Filter(condition, child) =>
- val required = requiredByParent ++ condition.references
- f.copy(child = pushDownRequiredColumns(child, required))
-
- case relation: DataSourceV2Relation => relation.reader match {
- case reader: SupportsPushDownRequiredColumns =>
- // TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now
- // it's possible that the mutable reader being updated by someone else, and we need to
- // always call `reader.pruneColumns` here to correct it.
- // assert(relation.output.toStructType == reader.readSchema(),
- // "Schema of data source reader does not match the relation plan.")
-
- val requiredColumns = relation.output.filter(requiredByParent.contains)
- reader.pruneColumns(requiredColumns.toStructType)
-
- val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
- val newOutput = reader.readSchema().map(_.name).map(nameToAttr)
- relation.copy(output = newOutput)
-
- case _ => relation
- }
-
- // TODO: there may be more operators that can be used to calculate the required columns. We
- // can add more and more in the future.
- case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet))
- }
- }
-
- /**
- * Finds a Filter node(with an optional Project child) above data source relation.
- */
- object FilterAndProject {
- // returns the project list, the filter condition and the data source relation.
- def unapply(plan: LogicalPlan)
- : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match {
-
- case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r))
-
- case Filter(condition, Project(fields, r: DataSourceV2Relation))
- if fields.forall(_.deterministic) =>
- val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e))
- val substituted = condition.transform {
- case a: Attribute => attributeMap.getOrElse(a, a)
- }
- Some((fields, substituted, r))
-
- case _ => None
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
index eefbcf4c0e087..ea4bda327f36f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.execution.datasources.v2
+import scala.util.control.NonFatal
+
import org.apache.spark.{SparkEnv, SparkException, TaskContext}
+import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
@@ -26,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution}
import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions}
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
@@ -53,6 +57,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
}
+ val useCommitCoordinator = writer.useCommitCoordinator
val rdd = query.execute()
val messages = new Array[WriterCommitMessage](rdd.partitions.length)
@@ -60,25 +65,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
s"The input RDD has ${messages.length} partitions.")
try {
- val runTask = writer match {
- // This case means that we're doing continuous processing. In microbatch streaming, the
- // StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch.
- case w: StreamWriter =>
- EpochCoordinatorRef.get(
- sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
- sparkContext.env)
- .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))
-
- (context: TaskContext, iter: Iterator[InternalRow]) =>
- DataWritingSparkTask.runContinuous(writeTask, context, iter)
- case _ =>
- (context: TaskContext, iter: Iterator[InternalRow]) =>
- DataWritingSparkTask.run(writeTask, context, iter)
- }
-
sparkContext.runJob(
rdd,
- runTask,
+ (context: TaskContext, iter: Iterator[InternalRow]) =>
+ DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator),
rdd.partitions.indices,
(index, message: WriterCommitMessage) => {
messages(index) = message
@@ -86,14 +76,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
}
)
- if (!writer.isInstanceOf[StreamWriter]) {
- logInfo(s"Data source writer $writer is committing.")
- writer.commit(messages)
- logInfo(s"Data source writer $writer committed.")
- }
+ logInfo(s"Data source writer $writer is committing.")
+ writer.commit(messages)
+ logInfo(s"Data source writer $writer committed.")
} catch {
- case _: InterruptedException if writer.isInstanceOf[StreamWriter] =>
- // Interruption is how continuous queries are ended, so accept and ignore the exception.
case cause: Throwable =>
logError(s"Data source writer $writer is aborting.")
try {
@@ -105,7 +91,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
throw new SparkException("Writing job failed.", cause)
}
logError(s"Data source writer $writer aborted.")
- throw new SparkException("Writing job aborted.", cause)
+ cause match {
+ // Only wrap non fatal exceptions.
+ case NonFatal(e) => throw new SparkException("Writing job aborted.", e)
+ case _ => throw cause
+ }
}
sparkContext.emptyRDD
@@ -116,70 +106,61 @@ object DataWritingSparkTask extends Logging {
def run(
writeTask: DataWriterFactory[InternalRow],
context: TaskContext,
- iter: Iterator[InternalRow]): WriterCommitMessage = {
- val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber())
+ iter: Iterator[InternalRow],
+ useCommitCoordinator: Boolean): WriterCommitMessage = {
+ val stageId = context.stageId()
+ val partId = context.partitionId()
+ val attemptId = context.attemptNumber()
+ val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0")
+ val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong)
// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
- iter.foreach(dataWriter.write)
- logInfo(s"Writer for partition ${context.partitionId()} is committing.")
- val msg = dataWriter.commit()
- logInfo(s"Writer for partition ${context.partitionId()} committed.")
+ while (iter.hasNext) {
+ dataWriter.write(iter.next())
+ }
+
+ val msg = if (useCommitCoordinator) {
+ val coordinator = SparkEnv.get.outputCommitCoordinator
+ val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId)
+ if (commitAuthorized) {
+ logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.")
+ dataWriter.commit()
+ } else {
+ val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit"
+ logInfo(message)
+ // throwing CommitDeniedException will trigger the catch block for abort
+ throw new CommitDeniedException(message, stageId, partId, attemptId)
+ }
+
+ } else {
+ logInfo(s"Writer for partition ${context.partitionId()} is committing.")
+ dataWriter.commit()
+ }
+
+ logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.")
+
msg
+
})(catchBlock = {
// If there is an error, abort this writer
- logError(s"Writer for partition ${context.partitionId()} is aborting.")
+ logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.")
dataWriter.abort()
- logError(s"Writer for partition ${context.partitionId()} aborted.")
+ logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.")
})
}
-
- def runContinuous(
- writeTask: DataWriterFactory[InternalRow],
- context: TaskContext,
- iter: Iterator[InternalRow]): WriterCommitMessage = {
- val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber())
- val epochCoordinator = EpochCoordinatorRef.get(
- context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
- SparkEnv.get)
- val currentMsg: WriterCommitMessage = null
- var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
-
- do {
- // write the data and commit this writer.
- Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
- try {
- iter.foreach(dataWriter.write)
- logInfo(s"Writer for partition ${context.partitionId()} is committing.")
- val msg = dataWriter.commit()
- logInfo(s"Writer for partition ${context.partitionId()} committed.")
- epochCoordinator.send(
- CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
- )
- currentEpoch += 1
- } catch {
- case _: InterruptedException =>
- // Continuous shutdown always involves an interrupt. Just finish the task.
- }
- })(catchBlock = {
- // If there is an error, abort this writer
- logError(s"Writer for partition ${context.partitionId()} is aborting.")
- dataWriter.abort()
- logError(s"Writer for partition ${context.partitionId()} aborted.")
- })
- } while (!context.isInterrupted())
-
- currentMsg
- }
}
class InternalRowDataWriterFactory(
rowWriterFactory: DataWriterFactory[Row],
schema: StructType) extends DataWriterFactory[InternalRow] {
- override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
+ override def createDataWriter(
+ partitionId: Int,
+ attemptNumber: Int,
+ epochId: Long): DataWriter[InternalRow] = {
new InternalRowDataWriter(
- rowWriterFactory.createDataWriter(partitionId, attemptNumber),
+ rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId),
RowEncoder.apply(schema).resolveAndBind())
}
}
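The useCommitCoordinator path added above follows Spark's usual output-commit protocol: before committing, each task attempt asks the driver whether it may commit, and a denied attempt aborts instead. The following is a minimal, self-contained sketch of that decision flow in plain Scala; the canCommit callback and CommitDenied exception are stand-ins for Spark's OutputCommitCoordinator and CommitDeniedException, not the real APIs.

    object CommitFlowSketch {
      final case class CommitDenied(stageId: Int, partId: Int, attemptId: Int)
        extends RuntimeException(s"stage $stageId, task $partId.$attemptId: commit denied")

      // Writes all rows, then commits only if the (hypothetical) coordinator allows it.
      def writePartition[T, M](
          rows: Iterator[T],
          write: T => Unit,
          commit: () => M,
          abort: () => Unit,
          canCommit: (Int, Int, Int) => Boolean,
          stageId: Int, partId: Int, attemptId: Int): M = {
        try {
          rows.foreach(write)
          if (canCommit(stageId, partId, attemptId)) commit()
          else throw CommitDenied(stageId, partId, attemptId)
        } catch {
          case e: Throwable =>
            abort() // a denied or failed attempt must leave no output behind
            throw e
        }
      }
    }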
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index daea6c39624d6..c55f9b8f1a7fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
+import scala.util.control.NonFatal
import org.apache.spark.{broadcast, SparkException}
import org.apache.spark.launcher.SparkLauncher
@@ -30,7 +31,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.HashedRelation
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.{SparkFatalException, ThreadUtils}
/**
* A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
@@ -69,7 +70,7 @@ case class BroadcastExchangeExec(
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
- SQLExecution.withExecutionId(sparkContext, executionId) {
+ SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
try {
val beforeCollect = System.nanoTime()
// Use executeCollect/executeCollectIterator to avoid conversion to Scala types
@@ -111,12 +112,18 @@ case class BroadcastExchangeExec(
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
broadcasted
} catch {
+      // SPARK-24294: To bypass the Scala bug https://github.com/scala/bug/issues/9554, we throw
+ // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
+ // will catch this exception and re-throw the wrapped fatal throwable.
case oe: OutOfMemoryError =>
- throw new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " +
+ throw new SparkFatalException(
+ new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " +
s"all worker nodes. As a workaround, you can either disable broadcast by setting " +
s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " +
s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value")
- .initCause(oe.getCause)
+ .initCause(oe.getCause))
+ case e if !NonFatal(e) =>
+ throw new SparkFatalException(e)
}
}
}(BroadcastExchangeExec.executionContext)
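For context on the SparkFatalException change: scala.concurrent promises do not reliably propagate fatal throwables (scala/bug#9554), so the job thread wraps the fatal error in an ordinary Exception and the thread awaiting the future unwraps it. Below is a small self-contained sketch of that wrap/unwrap pattern, with a hypothetical FatalWrapper class rather than Spark's SparkFatalException or ThreadUtils.

    import java.util.concurrent.Executors
    import scala.concurrent.{Await, ExecutionContext, Future}
    import scala.concurrent.duration.Duration
    import scala.util.control.NonFatal

    object FatalWrapperSketch {
      // Non-fatal carrier for a fatal throwable, so the Future machinery will propagate it.
      final class FatalWrapper(cause: Throwable) extends Exception(cause)

      implicit val ec: ExecutionContext =
        ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor())

      def run[T](body: => T): Future[T] = Future {
        try body catch {
          case e if !NonFatal(e) => throw new FatalWrapper(e) // e.g. OutOfMemoryError
        }
      }

      def await[T](f: Future[T], timeout: Duration): T =
        try Await.result(f, timeout) catch {
          case w: FatalWrapper => throw w.getCause // rethrow the original fatal error
        }
    }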
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index e3d28388c5470..ad95879d86f42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.exchange
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions._
@@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val leftKeysBuffer = ArrayBuffer[Expression]()
val rightKeysBuffer = ArrayBuffer[Expression]()
+ val pickedIndexes = mutable.Set[Int]()
+ val keysAndIndexes = currentOrderOfKeys.zipWithIndex
expectedOrderOfKeys.foreach(expression => {
- val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+ val index = keysAndIndexes.find { case (e, idx) =>
+        // The same key may be used many times, so skip the occurrences we have
+        // already picked.
+ e.semanticEquals(expression) && !pickedIndexes.contains(idx)
+ }.map(_._2).get
+ pickedIndexes += index
leftKeysBuffer.append(leftKeys(index))
rightKeysBuffer.append(rightKeys(index))
})
@@ -270,7 +278,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
* partitioning of the join nodes' children.
*/
private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
- plan.transformUp {
+ plan match {
case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
@@ -288,6 +296,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+
+ case other => other
}
}
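The pickedIndexes set above exists because the same join key can occur more than once in the expected ordering; each occurrence in the current ordering must be consumed at most once. The matching logic can be illustrated without catalyst at all, using strings in place of expressions (a sketch only, not the rule itself).

    import scala.collection.mutable

    object ReorderKeysSketch {
      /** Returns, for each expected key, the index of an unused matching key in `current`. */
      def matchIndices(expected: Seq[String], current: Seq[String]): Seq[Int] = {
        val picked = mutable.Set.empty[Int]
        val indexed = current.zipWithIndex
        expected.map { key =>
          val idx = indexed
            .find { case (k, i) => k == key && !picked.contains(i) }
            .map(_._2)
            .getOrElse(sys.error(s"no unused match for $key"))
          picked += idx
          idx
        }
      }

      def main(args: Array[String]): Unit = {
        // "a" appears twice; each expected occurrence is matched to a distinct index.
        println(matchIndices(Seq("a", "b", "a"), Seq("a", "a", "b"))) // List(0, 2, 1)
      }
    }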
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
index 78f11ca8d8c78..051e610eb2705 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
@@ -232,16 +232,16 @@ class ExchangeCoordinator(
// number of post-shuffle partitions.
val partitionStartIndices =
if (mapOutputStatistics.length == 0) {
- None
+ Array.empty[Int]
} else {
- Some(estimatePartitionStartIndices(mapOutputStatistics))
+ estimatePartitionStartIndices(mapOutputStatistics)
}
var k = 0
while (k < numExchanges) {
val exchange = exchanges(k)
val rdd =
- exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices)
+ exchange.preparePostShuffleRDD(shuffleDependencies(k), Some(partitionStartIndices))
newPostShuffleRDDs.put(exchange, rdd)
k += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 4d95ee34f30de..b89203719541b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -153,12 +153,9 @@ object ShuffleExchangeExec {
* See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue.
*
* @param partitioner the partitioner for the shuffle
- * @param serializer the serializer that will be used to write rows
* @return true if rows should be copied before being shuffled, false otherwise
*/
- private def needToCopyObjectsBeforeShuffle(
- partitioner: Partitioner,
- serializer: Serializer): Boolean = {
+ private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = {
// Note: even though we only use the partitioner's `numPartitions` field, we require it to be
// passed instead of directly passing the number of partitions in order to guard against
// corner-cases where a partitioner constructed with `numPartitions` partitions may output
@@ -167,22 +164,24 @@ object ShuffleExchangeExec {
val shuffleManager = SparkEnv.get.shuffleManager
val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager]
val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+ val numParts = partitioner.numPartitions
if (sortBasedShuffleOn) {
- val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
- if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
+ if (numParts <= bypassMergeThreshold) {
// If we're using the original SortShuffleManager and the number of output partitions is
// sufficiently small, then Spark will fall back to the hash-based shuffle write path, which
// doesn't buffer deserialized records.
// Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
false
- } else if (serializer.supportsRelocationOfSerializedObjects) {
+ } else if (numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
// SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records
// prior to sorting them. This optimization is only applied in cases where shuffle
// dependency does not specify an aggregator or ordering and the record serializer has
- // certain properties. If this optimization is enabled, we can safely avoid the copy.
+ // certain properties and the number of partitions doesn't exceed the limitation. If this
+ // optimization is enabled, we can safely avoid the copy.
//
- // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only
- // need to check whether the optimization is enabled and supported by our serializer.
+ // Exchange never configures its ShuffledRDDs with aggregators or key orderings, and the
+      // serializer in Spark SQL always satisfies the properties, so we only need to check whether
+ // the number of partitions exceeds the limitation.
false
} else {
// Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must
@@ -298,7 +297,7 @@ object ShuffleExchangeExec {
rdd
}
- if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+ if (needToCopyObjectsBeforeShuffle(part)) {
newRdd.mapPartitionsInternal { iter =>
val getPartitionKey = getPartitionKeyExtractor()
iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 1918fcc5482db..0da0e8610c392 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -22,12 +22,13 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.LongType
+import org.apache.spark.sql.types.{BooleanType, LongType}
import org.apache.spark.util.TaskCompletionListener
/**
@@ -182,16 +183,17 @@ case class BroadcastHashJoinExec(
// the variables are needed even there is no matched rows
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
- val code = s"""
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
|boolean $isNull = true;
- |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
|if ($matched != null) {
| ${ev.code}
| $isNull = ${ev.isNull};
| $value = ${ev.value};
|}
""".stripMargin
- ExprCode(code, isNull, value)
+ ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType))
}
}
}
@@ -486,7 +488,8 @@ case class BroadcastHashJoinExec(
s"$existsVar = true;"
}
- val resultVar = input ++ Seq(ExprCode("", "false", existsVar))
+ val resultVar = input ++ Seq(ExprCode.forNonNullValue(
+ JavaCode.variable(existsVar, BooleanType)))
if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 1465346eb802d..20ce01f4ce8cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -557,7 +557,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
def append(key: Long, row: UnsafeRow): Unit = {
val sizeInBytes = row.getSizeInBytes
if (sizeInBytes >= (1 << SIZE_BITS)) {
- sys.error("Does not support row that is larger than 256M")
+ throw new UnsupportedOperationException("Does not support row that is larger than 256M")
}
if (key < minKey) {
@@ -567,19 +567,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
maxKey = key
}
- // There is 8 bytes for the pointer to next value
- if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
- val used = page.length
- if (used >= (1 << 30)) {
- sys.error("Can not build a HashedRelation that is larger than 8G")
- }
- ensureAcquireMemory(used * 8L * 2)
- val newPage = new Array[Long](used * 2)
- Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
- cursor - Platform.LONG_ARRAY_OFFSET)
- page = newPage
- freeMemory(used * 8L)
- }
+ grow(row.getSizeInBytes)
// copy the bytes of UnsafeRow
val offset = cursor
@@ -615,7 +603,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
growArray()
} else if (numKeys > array.length / 2 * 0.75) {
// The fill ratio should be less than 0.75
- sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
+ throw new UnsupportedOperationException(
+        "Cannot build HashedRelation with more than 1/3 billion unique keys")
}
}
} else {
@@ -626,6 +615,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
}
+ private def grow(inputRowSize: Int): Unit = {
+ // There is 8 bytes for the pointer to next value
+ val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8
+ if (neededNumWords > page.length) {
+ if (neededNumWords > (1 << 30)) {
+ throw new UnsupportedOperationException(
+ "Can not build a HashedRelation that is larger than 8G")
+ }
+ val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30))
+ ensureAcquireMemory(newNumWords * 8L)
+ val newPage = new Array[Long](newNumWords.toInt)
+ Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+ cursor - Platform.LONG_ARRAY_OFFSET)
+ val used = page.length
+ page = newPage
+ freeMemory(used * 8L)
+ }
+ }
+
private def growArray(): Unit = {
var old_array = array
val n = array.length
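The arithmetic in the new grow() above works in 8-byte words: it computes how many words the page must hold after appending the row plus its 8-byte next-value pointer, and grows to at least double the current page, capped at 2^30 words (8 GB). A standalone sketch of just that sizing calculation follows; the 8 GB cap is carried over from the code, everything else is illustrative.

    object GrowSketch {
      private val MaxWords = 1 << 30 // 2^30 words * 8 bytes = 8 GB cap

      /** Returns the new page length in words, or the current one if no growth is needed. */
      def newPageWords(currentWords: Int, usedBytes: Long, rowSizeBytes: Int): Long = {
        // 8 extra bytes for the pointer to the next value; round up to whole words.
        val neededWords = (usedBytes + 8 + rowSizeBytes + 7) / 8
        if (neededWords <= currentWords) currentWords.toLong
        else if (neededWords > MaxWords) sys.error("row map would exceed 8G")
        else math.max(neededWords, math.min(currentWords * 2L, MaxWords.toLong))
      }

      def main(args: Array[String]): Unit = {
        // A nearly full 1M-word page grows to 2M words when a 256-byte row is appended.
        println(newPageWords(currentWords = 1 << 20, usedBytes = (1 << 20) * 8L - 64, rowSizeBytes = 256))
      }
    }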
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 2de2f30eb05d3..f4b9d132122e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -22,11 +22,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport,
-ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.collection.BitSet
@@ -516,13 +516,13 @@ case class SortMergeJoinExec(
ctx.INPUT_ROW = leftRow
left.output.zipWithIndex.map { case (a, i) =>
val value = ctx.freshName("value")
- val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
- val javaType = ctx.javaType(a.dataType)
- val defaultValue = ctx.defaultValue(a.dataType)
+ val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString)
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val defaultValue = CodeGenerator.defaultValue(a.dataType)
if (a.nullable) {
val isNull = ctx.freshName("isNull")
val code =
- s"""
+ code"""
|$isNull = $leftRow.isNullAt($i);
|$value = $isNull ? $defaultValue : ($valueCode);
""".stripMargin
@@ -531,11 +531,12 @@ case class SortMergeJoinExec(
|boolean $isNull = false;
|$javaType $value = $defaultValue;
""".stripMargin
- (ExprCode(code, isNull, value), leftVarsDecl)
+ (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)),
+ leftVarsDecl)
} else {
- val code = s"$value = $valueCode;"
+ val code = code"$value = $valueCode;"
val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
- (ExprCode(code, "false", value), leftVarsDecl)
+ (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl)
}
}.unzip
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index cccee63bc0680..66bcda8913738 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.util.Utils
@@ -71,7 +71,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
}
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
- val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
+ val stopEarly =
+ ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
ctx.addNewFunction("stopEarly", s"""
@Override
@@ -79,7 +80,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
return $stopEarly;
}
""", inlineToOuterClass = true)
- val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0
+ val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
s"""
| if ($countTerm < $limit) {
| $countTerm += 1;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 5fcdcddca7d51..01e19bddbfb66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -70,19 +70,13 @@ class ArrowPythonRunner(
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdout writer for $pythonExec", 0, Long.MaxValue)
-
val root = VectorSchemaRoot.create(arrowSchema, allocator)
- val arrowWriter = ArrowWriter.create(root)
-
- context.addTaskCompletionListener { _ =>
- root.close()
- allocator.close()
- }
-
- val writer = new ArrowStreamWriter(root, null, dataOut)
- writer.start()
Utils.tryWithSafeFinally {
+ val arrowWriter = ArrowWriter.create(root)
+ val writer = new ArrowStreamWriter(root, null, dataOut)
+ writer.start()
+
while (inputIterator.hasNext) {
val nextBatch = inputIterator.next()
@@ -94,8 +88,21 @@ class ArrowPythonRunner(
writer.writeBatch()
arrowWriter.reset()
}
- } {
+      // `end` writes the footer to the output stream and doesn't clean up any resources.
+      // It could throw an exception if the output stream is already closed, so it should
+      // stay inside the try block.
writer.end()
+ } {
+      // If we closed root and allocator in a TaskCompletionListener, there could be a race
+      // condition where the writer thread keeps writing to the VectorSchemaRoot while
+      // it is being closed by the TaskCompletionListener.
+      // Closing root and allocator here is cleaner because they are owned by the writer
+      // thread and are only visible to the writer thread.
+      //
+      // If the writer thread is interrupted by the TaskCompletionListener, it is either
+      // (1) in the try block, in which case it will get an InterruptedException while
+      // performing IO and fall through to the finally block, or (2) already in the finally
+      // block, in which case it will ignore the interruption and close the resources.
root.close()
allocator.close()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 9d56f48249982..1e096100f7f43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
*/
private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
e.isInstanceOf[AggregateExpression] ||
- PythonUDF.isGroupAggPandasUDF(e) ||
+ PythonUDF.isGroupedAggPandasUDF(e) ||
agg.groupingExpressions.exists(_.semanticEquals(e))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index c798fe5a92c54..513e174c7733e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
@@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec(
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
- val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
- val schema = StructType(child.schema.drop(groupingAttributes.length))
val sessionLocalTimeZone = conf.sessionLocalTimeZone
val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
+ // Deduplicate the grouping attributes.
+    // If a grouping attribute also appears in the data attributes, we don't need to send it
+    // to the Python worker again. If a grouping attribute is not in the data attributes, we
+    // do need to send it to the Python worker.
+    //
+    // We use argOffsets to distinguish grouping attributes from data attributes as follows:
+ //
+ // argOffsets[0] is the length of grouping attributes
+ // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
+ // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
+
+ val dataAttributes = child.output.drop(groupingAttributes.length)
+ val groupingIndicesInData = groupingAttributes.map { attribute =>
+ dataAttributes.indexWhere(attribute.semanticEquals)
+ }
+
+ val groupingArgOffsets = new ArrayBuffer[Int]
+ val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
+ val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
+
+    // Non-duplicate grouping attributes are added to nonDupGroupingAttributes and
+    // their offsets are 0, 1, 2 ...
+    // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes; their
+    // offsets are n + index, where n is the total number of non-duplicate grouping
+    // attributes and index is the index in the data attributes that the grouping attribute
+ // is a duplicate of.
+
+ groupingAttributes.zip(groupingIndicesInData).foreach {
+ case (attribute, index) =>
+ if (index == -1) {
+ groupingArgOffsets += nonDupGroupingAttributes.length
+ nonDupGroupingAttributes += attribute
+ } else {
+ groupingArgOffsets += index + nonDupGroupingSize
+ }
+ }
+
+ val dataArgOffsets = nonDupGroupingAttributes.length until
+ (nonDupGroupingAttributes.length + dataAttributes.length)
+
+ val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
+
+ // Attributes after deduplication
+ val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
+ val dedupSchema = StructType.fromAttributes(dedupAttributes)
+
inputRDD.mapPartitionsInternal { iter =>
val grouped = if (groupingAttributes.isEmpty) {
Iterator(iter)
} else {
val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
- val dropGrouping =
- UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output)
+ val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
groupedIter.map {
- case (_, groupedRowIter) => groupedRowIter.map(dropGrouping)
+ case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
}
}
@@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec(
val columnarBatchIter = new ArrowPythonRunner(
chainedFunc, bufferSize, reuseWorker,
- PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(grouped, context.partitionId(), context)
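The argOffsets layout documented in the comments above can be reproduced without Spark: grouping columns already present in the data columns are referenced by their data index shifted past the non-duplicate grouping columns, while the remaining grouping columns get fresh leading offsets. A small sketch using strings for attribute names (illustration only; the real code works on catalyst Attributes):

    import scala.collection.mutable.ArrayBuffer

    object ArgOffsetsSketch {
      /** Returns (deduplicated columns, argOffsets = [numGrouping, grouping offsets..., data offsets...]). */
      def layout(grouping: Seq[String], data: Seq[String]): (Seq[String], Array[Int]) = {
        val groupingIndexInData = grouping.map(col => data.indexOf(col))
        val nonDupGrouping = ArrayBuffer.empty[String]
        val groupingOffsets = ArrayBuffer.empty[Int]
        val nonDupCount = groupingIndexInData.count(_ == -1)

        grouping.zip(groupingIndexInData).foreach { case (col, idx) =>
          if (idx == -1) {             // not in the data columns: gets a fresh leading offset
            groupingOffsets += nonDupGrouping.length
            nonDupGrouping += col
          } else {                     // duplicate: point at its position among the data columns
            groupingOffsets += nonDupCount + idx
          }
        }
        val dataOffsets = nonDupGrouping.length until (nonDupGrouping.length + data.length)
        (nonDupGrouping.toSeq ++ data, Array(grouping.length) ++ groupingOffsets ++ dataOffsets)
      }

      def main(args: Array[String]): Unit = {
        val (cols, offsets) = layout(grouping = Seq("id", "bucket"), data = Seq("id", "v"))
        println(cols)           // List(bucket, id, v)
        println(offsets.toList) // List(2, 1, 0, 1, 2)
      }
    }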
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
new file mode 100644
index 0000000000000..a58773122922f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.locks.ReentrantLock
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python._
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.{NextIterator, Utils}
+
+class PythonForeachWriter(func: PythonFunction, schema: StructType)
+ extends ForeachWriter[UnsafeRow] {
+
+ private lazy val context = TaskContext.get()
+ private lazy val buffer = new PythonForeachWriter.UnsafeRowBuffer(
+ context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length)
+ private lazy val inputRowIterator = buffer.iterator
+
+ private lazy val inputByteIterator = {
+ EvaluatePython.registerPicklers()
+ val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) }
+ new SerDeUtil.AutoBatchedPickler(objIterator)
+ }
+
+ private lazy val pythonRunner = {
+ val conf = SparkEnv.get.conf
+ val bufferSize = conf.getInt("spark.buffer.size", 65536)
+ val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)
+ PythonRunner(func, bufferSize, reuseWorker)
+ }
+
+ private lazy val outputIterator =
+ pythonRunner.compute(inputByteIterator, context.partitionId(), context)
+
+ override def open(partitionId: Long, version: Long): Boolean = {
+ outputIterator // initialize everything
+ TaskContext.get.addTaskCompletionListener { _ => buffer.close() }
+ true
+ }
+
+ override def process(value: UnsafeRow): Unit = {
+ buffer.add(value)
+ }
+
+ override def close(errorOrNull: Throwable): Unit = {
+ buffer.allRowsAdded()
+ if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one
+ }
+}
+
+object PythonForeachWriter {
+
+ /**
+ * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter.
+ * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader
+ * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python
+   * worker stdin). Adds to the buffer are non-blocking; reads through the buffer's iterator
+   * block until new data is available or all data has been added.
+ *
+ * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue
+ * across memory and local disk. However, HybridRowQueue is designed to be used only with
+   * EvalPythonExec where the reader is always behind the writer, that is, the reader does not
+   * try to read n+1 rows if the writer has only written n rows at any point of time. This
+   * assumption does not hold for PythonForeachWriter, where rows may be added at a different
+   * rate than they are consumed by the Python worker. Hence, to keep the invariant of the
+   * reader being behind the writer while using HybridRowQueue, the buffer does the following:
+ * - Keeps a count of the rows in the HybridRowQueue
+ * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not
+ * try to read more rows than what has been written.
+ *
+ * The implementation of the blocking iterator (ReentrantLock, Condition, etc.) has been borrowed
+ * from that of ArrayBlockingQueue.
+ */
+ class UnsafeRowBuffer(taskMemoryManager: TaskMemoryManager, tempDir: File, numFields: Int)
+ extends Logging {
+ private val queue = HybridRowQueue(taskMemoryManager, tempDir, numFields)
+ private val lock = new ReentrantLock()
+ private val unblockRemove = lock.newCondition()
+
+ // All of these are guarded by `lock`
+ private var count = 0L
+ private var allAdded = false
+ private var exception: Throwable = null
+
+ val iterator = new NextIterator[UnsafeRow] {
+ override protected def getNext(): UnsafeRow = {
+ val row = remove()
+ if (row == null) finished = true
+ row
+ }
+ override protected def close(): Unit = { }
+ }
+
+ def add(row: UnsafeRow): Unit = withLock {
+      assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python " +
+ s"[count = $count, allAdded = $allAdded, exception = $exception]")
+ count += 1
+ unblockRemove.signal()
+ logTrace(s"Added $row, $count left")
+ }
+
+ private def remove(): UnsafeRow = withLock {
+ while (count == 0 && !allAdded && exception == null) {
+ unblockRemove.await(100, TimeUnit.MILLISECONDS)
+ }
+
+ // If there was any error in the adding thread, then rethrow it in the removing thread
+ if (exception != null) throw exception
+
+ if (count > 0) {
+ val row = queue.remove()
+ assert(row != null, "HybridRowQueue.remove() returned null " +
+ s"[count = $count, allAdded = $allAdded, exception = $exception]")
+ count -= 1
+ logTrace(s"Removed $row, $count left")
+ row
+ } else {
+ null
+ }
+ }
+
+ def allRowsAdded(): Unit = withLock {
+ allAdded = true
+ unblockRemove.signal()
+ }
+
+ def close(): Unit = { queue.close() }
+
+ private def withLock[T](f: => T): T = {
+ lock.lockInterruptibly()
+ try { f } catch {
+ case e: Throwable =>
+ if (exception == null) exception = e
+ throw e
+ } finally { lock.unlock() }
+ }
+ }
+}
+
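The UnsafeRowBuffer above is, at its core, a count-guarded queue: the writer thread enqueues and signals, while the reader blocks on a condition until a row is available or the writer declares it is done. Here is a stripped-down sketch of that lock/condition protocol over a plain in-memory queue, without HybridRowQueue, spilling, or error propagation.

    import java.util.concurrent.TimeUnit
    import java.util.concurrent.locks.ReentrantLock
    import scala.collection.mutable

    class BlockingRowBuffer[T] {
      private val queue = mutable.Queue.empty[T]
      private val lock = new ReentrantLock()
      private val rowAvailable = lock.newCondition()
      private var allAdded = false // guarded by `lock`

      def add(row: T): Unit = withLock {
        queue.enqueue(row)
        rowAvailable.signal()
      }

      def allRowsAdded(): Unit = withLock {
        allAdded = true
        rowAvailable.signal()
      }

      /** Blocks until a row is available; returns None once the writer is done and the queue drained. */
      def remove(): Option[T] = withLock {
        while (queue.isEmpty && !allAdded) {
          rowAvailable.await(100, TimeUnit.MILLISECONDS)
        }
        if (queue.nonEmpty) Some(queue.dequeue()) else None
      }

      private def withLock[A](f: => A): A = {
        lock.lockInterruptibly()
        try f finally lock.unlock()
      }
    }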
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
new file mode 100644
index 0000000000000..c76832a1a3829
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+case class WindowInPandasExec(
+ windowExpression: Seq[NamedExpression],
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ child: SparkPlan) extends UnaryExecNode {
+
+ override def output: Seq[Attribute] =
+ child.output ++ windowExpression.map(_.toAttribute)
+
+ override def requiredChildDistribution: Seq[Distribution] = {
+ if (partitionSpec.isEmpty) {
+ // Only show warning when the number of bytes is larger than 100 MB?
+ logWarning("No Partition Defined for Window operation! Moving all data to a single "
+ + "partition, this can cause serious performance degradation.")
+ AllTuples :: Nil
+ } else {
+ ClusteredDistribution(partitionSpec) :: Nil
+ }
+ }
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+ udf.children match {
+ case Seq(u: PythonUDF) =>
+ val (chained, children) = collectFunctions(u)
+ (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+ case children =>
+ // There should not be any other UDFs, or the children can't be evaluated directly.
+ assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+ (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+ }
+ }
+
+ /**
+ * Create the resulting projection.
+ *
+ * This method uses Code Generation. It can only be used on the executor side.
+ *
+ * @param expressions unbound ordered function expressions.
+ * @return the final resulting projection.
+ */
+ private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
+ val references = expressions.zipWithIndex.map { case (e, i) =>
+ // Results of window expressions will be on the right side of child's output
+ BoundReference(child.output.size + i, e.dataType, e.nullable)
+ }
+ val unboundToRefMap = expressions.zip(references).toMap
+ val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
+ UnsafeProjection.create(
+ child.output ++ patchedWindowExpression,
+ child.output)
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val inputRDD = child.execute()
+
+ val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+ val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+ val sessionLocalTimeZone = conf.sessionLocalTimeZone
+ val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
+
+ // Extract window expressions and window functions
+ val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e })
+
+ val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF])
+
+ val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
+
+ // Filter child output attributes down to only those that are UDF inputs.
+ // Also eliminate duplicate UDF inputs.
+ val allInputs = new ArrayBuffer[Expression]
+ val dataTypes = new ArrayBuffer[DataType]
+ val argOffsets = inputs.map { input =>
+ input.map { e =>
+ if (allInputs.exists(_.semanticEquals(e))) {
+ allInputs.indexWhere(_.semanticEquals(e))
+ } else {
+ allInputs += e
+ dataTypes += e.dataType
+ allInputs.length - 1
+ }
+ }.toArray
+ }.toArray
+
+ // Schema of input rows to the python runner
+ val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
+ StructField(s"_$i", dt)
+ })
+
+ inputRDD.mapPartitionsInternal { iter =>
+ val context = TaskContext.get()
+
+ val grouped = if (partitionSpec.isEmpty) {
+        // Use an empty unsafe row as a placeholder for the grouping key
+ Iterator((new UnsafeRow(), iter))
+ } else {
+ GroupedIterator(iter, partitionSpec, child.output)
+ }
+
+ // The queue used to buffer input rows so we can drain it to
+ // combine input with output from Python.
+ val queue = HybridRowQueue(context.taskMemoryManager(),
+ new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+ context.addTaskCompletionListener { _ =>
+ queue.close()
+ }
+
+ val inputProj = UnsafeProjection.create(allInputs, child.output)
+ val pythonInput = grouped.map { case (_, rows) =>
+ rows.map { row =>
+ queue.add(row.asInstanceOf[UnsafeRow])
+ inputProj(row)
+ }
+ }
+
+ val windowFunctionResult = new ArrowPythonRunner(
+ pyFuncs, bufferSize, reuseWorker,
+ PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
+ argOffsets, windowInputSchema,
+ sessionLocalTimeZone, pandasRespectSessionTimeZone)
+ .compute(pythonInput, context.partitionId(), context)
+
+ val joined = new JoinedRow
+ val resultProj = createResultProjection(expressions)
+
+ windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput =>
+ val leftRow = queue.remove()
+ val joinedRow = joined(leftRow, windowOutput)
+ resultProj(joinedRow)
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala
new file mode 100644
index 0000000000000..606ba250ad9d2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala
@@ -0,0 +1,349 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{FileNotFoundException, IOException, OutputStream}
+import java.util.{EnumSet, UUID}
+
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs._
+import org.apache.hadoop.fs.local.{LocalFs, RawLocalFs}
+import org.apache.hadoop.fs.permission.FsPermission
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager.RenameHelperMethods
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
+
+/**
+ * An interface to abstract out all operations related to streaming checkpoints. Most importantly,
+ * the key operation this interface provides is `createAtomic(path, overwrite)` which returns a
+ * `CancellableFSDataOutputStream`. This method is used by [[HDFSMetadataLog]] and
+ * [[org.apache.spark.sql.execution.streaming.state.StateStore StateStore]] implementations
+ * to write a complete checkpoint file atomically (i.e. no partial file will be visible), with or
+ * without overwrite.
+ *
+ * This higher-level interface above the Hadoop FileSystem is necessary because
+ * different implementations of FileSystem/FileContext may have different combinations of operations
+ * to provide the desired atomic guarantees (e.g. write-to-temp-file-and-rename,
+ * direct-write-and-cancel-on-failure) and this abstraction allow different implementations while
+ * direct-write-and-cancel-on-failure) and this abstraction allows different implementations while
+ */
+trait CheckpointFileManager {
+
+ import org.apache.spark.sql.execution.streaming.CheckpointFileManager._
+
+ /**
+ * Create a file and make its contents available atomically after the output stream is closed.
+ *
+ * @param path Path to create
+ * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to
+ * overwrite the file if it already exists. It should not throw
+ * any exception if the file exists. However, if false, then the
+ *                            implementation must not overwrite if the file already exists and
+ * must throw `FileAlreadyExistsException` in that case.
+ */
+ def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream
+
+ /** Open a file for reading, or throw exception if it does not exist. */
+ def open(path: Path): FSDataInputStream
+
+ /** List the files in a path that match a filter. */
+ def list(path: Path, filter: PathFilter): Array[FileStatus]
+
+ /** List all the files in a path. */
+ def list(path: Path): Array[FileStatus] = {
+ list(path, new PathFilter { override def accept(path: Path): Boolean = true })
+ }
+
+  /** Make directory at the given path and all its parent directories as needed. */
+ def mkdirs(path: Path): Unit
+
+ /** Whether path exists */
+ def exists(path: Path): Boolean
+
+ /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */
+ def delete(path: Path): Unit
+
+  /** Whether the default file system this implementation operates on is the local file system. */
+ def isLocal: Boolean
+}
+
+object CheckpointFileManager extends Logging {
+
+ /**
+ * Additional methods in CheckpointFileManager implementations that allows
+ * [[RenameBasedFSDataOutputStream]] get atomicity by write-to-temp-file-and-rename
+ */
+  sealed trait RenameHelperMethods { self: CheckpointFileManager =>
+ /** Create a file with overwrite. */
+ def createTempFile(path: Path): FSDataOutputStream
+
+ /**
+ * Rename a file.
+ *
+ * @param srcPath Source path to rename
+ * @param dstPath Destination path to rename to
+ * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to
+ * overwrite the file if it already exists. It should not throw
+ * any exception if the file exists. However, if false, then the
+ *                            implementation must not overwrite if the file already exists and
+ * must throw `FileAlreadyExistsException` in that case.
+ */
+ def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit
+ }
+
+ /**
+ * An interface to add the cancel() operation to [[FSDataOutputStream]]. This is used
+ * mainly by `CheckpointFileManager.createAtomic` to write a file atomically.
+ *
+ * @see [[CheckpointFileManager]].
+ */
+ abstract class CancellableFSDataOutputStream(protected val underlyingStream: OutputStream)
+ extends FSDataOutputStream(underlyingStream, null) {
+ /** Cancel the `underlyingStream` and ensure that the output file is not generated. */
+ def cancel(): Unit
+ }
+
+ /**
+ * An implementation of [[CancellableFSDataOutputStream]] that writes a file atomically by writing
+   * to a temporary file and then renaming it.
+ */
+ sealed class RenameBasedFSDataOutputStream(
+ fm: CheckpointFileManager with RenameHelperMethods,
+ finalPath: Path,
+ tempPath: Path,
+ overwriteIfPossible: Boolean)
+ extends CancellableFSDataOutputStream(fm.createTempFile(tempPath)) {
+
+ def this(fm: CheckpointFileManager with RenameHelperMethods, path: Path, overwrite: Boolean) = {
+ this(fm, path, generateTempPath(path), overwrite)
+ }
+
+ logInfo(s"Writing atomically to $finalPath using temp file $tempPath")
+ @volatile private var terminated = false
+
+ override def close(): Unit = synchronized {
+ try {
+ if (terminated) return
+ underlyingStream.close()
+ try {
+ fm.renameTempFile(tempPath, finalPath, overwriteIfPossible)
+ } catch {
+ case fe: FileAlreadyExistsException =>
+ logWarning(
+ s"Failed to rename temp file $tempPath to $finalPath because file exists", fe)
+ if (!overwriteIfPossible) throw fe
+ }
+ logInfo(s"Renamed temp file $tempPath to $finalPath")
+ } finally {
+ terminated = true
+ }
+ }
+
+ override def cancel(): Unit = synchronized {
+ try {
+ if (terminated) return
+ underlyingStream.close()
+ fm.delete(tempPath)
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Error cancelling write to $finalPath", e)
+ } finally {
+ terminated = true
+ }
+ }
+ }
+
+
+ /** Create an instance of [[CheckpointFileManager]] based on the path and configuration. */
+ def create(path: Path, hadoopConf: Configuration): CheckpointFileManager = {
+ val fileManagerClass = hadoopConf.get(
+ SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key)
+ if (fileManagerClass != null) {
+ return Utils.classForName(fileManagerClass)
+ .getConstructor(classOf[Path], classOf[Configuration])
+ .newInstance(path, hadoopConf)
+ .asInstanceOf[CheckpointFileManager]
+ }
+ try {
+      // Try to create a manager based on `FileContext` because HDFS's `FileContext.rename()`
+      // gives atomic renames, which is what we rely on for the default implementation of
+ // `CheckpointFileManager.createAtomic`.
+ new FileContextBasedCheckpointFileManager(path, hadoopConf)
+ } catch {
+ case e: UnsupportedFileSystemException =>
+ logWarning(
+ "Could not use FileContext API for managing Structured Streaming checkpoint files at " +
+ s"$path. Using FileSystem API instead for managing log files. If the implementation " +
+            s"of FileSystem.rename() is not atomic, then the correctness and fault-tolerance of " +
+ s"your Structured Streaming is not guaranteed.")
+ new FileSystemBasedCheckpointFileManager(path, hadoopConf)
+ }
+ }
+
+ private def generateTempPath(path: Path): Path = {
+ val tc = org.apache.spark.TaskContext.get
+ val tid = if (tc != null) ".TID" + tc.taskAttemptId else ""
+ new Path(path.getParent, s".${path.getName}.${UUID.randomUUID}${tid}.tmp")
+ }
+}
+
+
+/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileSystem]] API. */
+class FileSystemBasedCheckpointFileManager(path: Path, hadoopConf: Configuration)
+ extends CheckpointFileManager with RenameHelperMethods with Logging {
+
+ import CheckpointFileManager._
+
+ protected val fs = path.getFileSystem(hadoopConf)
+
+ override def list(path: Path, filter: PathFilter): Array[FileStatus] = {
+ fs.listStatus(path, filter)
+ }
+
+ override def mkdirs(path: Path): Unit = {
+ fs.mkdirs(path, FsPermission.getDirDefault)
+ }
+
+ override def createTempFile(path: Path): FSDataOutputStream = {
+ fs.create(path, true)
+ }
+
+ override def createAtomic(
+ path: Path,
+ overwriteIfPossible: Boolean): CancellableFSDataOutputStream = {
+ new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible)
+ }
+
+ override def open(path: Path): FSDataInputStream = {
+ fs.open(path)
+ }
+
+  override def exists(path: Path): Boolean = {
+    try {
+      fs.getFileStatus(path) != null
+    } catch {
+      case _: FileNotFoundException => false
+    }
+  }
+
+ override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = {
+ if (!overwriteIfPossible && fs.exists(dstPath)) {
+ throw new FileAlreadyExistsException(
+ s"Failed to rename $srcPath to $dstPath as destination already exists")
+ }
+
+ if (!fs.rename(srcPath, dstPath)) {
+      // FileSystem.rename() returning false is very ambiguous as it can happen for many reasons.
+      // This makes a best-effort attempt to throw the most appropriate exception.
+ if (fs.exists(dstPath)) {
+ if (!overwriteIfPossible) {
+ throw new FileAlreadyExistsException(s"Failed to rename as $dstPath already exists")
+ }
+ } else if (!fs.exists(srcPath)) {
+ throw new FileNotFoundException(s"Failed to rename as $srcPath was not found")
+ } else {
+ val msg = s"Failed to rename temp file $srcPath to $dstPath as rename returned false"
+ logWarning(msg)
+ throw new IOException(msg)
+ }
+ }
+ }
+
+ override def delete(path: Path): Unit = {
+ try {
+ fs.delete(path, true)
+ } catch {
+ case e: FileNotFoundException =>
+ logInfo(s"Failed to delete $path as it does not exist")
+ // ignore if file has already been deleted
+ }
+ }
+
+ override def isLocal: Boolean = fs match {
+ case _: LocalFileSystem | _: RawLocalFileSystem => true
+ case _ => false
+ }
+}
+
+
+/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileContext]] API. */
+class FileContextBasedCheckpointFileManager(path: Path, hadoopConf: Configuration)
+ extends CheckpointFileManager with RenameHelperMethods with Logging {
+
+ import CheckpointFileManager._
+
+ private val fc = if (path.toUri.getScheme == null) {
+ FileContext.getFileContext(hadoopConf)
+ } else {
+ FileContext.getFileContext(path.toUri, hadoopConf)
+ }
+
+ override def list(path: Path, filter: PathFilter): Array[FileStatus] = {
+ fc.util.listStatus(path, filter)
+ }
+
+ override def mkdirs(path: Path): Unit = {
+ fc.mkdir(path, FsPermission.getDirDefault, true)
+ }
+
+ override def createTempFile(path: Path): FSDataOutputStream = {
+ import CreateFlag._
+ import Options._
+ fc.create(
+ path, EnumSet.of(CREATE, OVERWRITE), CreateOpts.checksumParam(ChecksumOpt.createDisabled()))
+ }
+
+ override def createAtomic(
+ path: Path,
+ overwriteIfPossible: Boolean): CancellableFSDataOutputStream = {
+ new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible)
+ }
+
+ override def open(path: Path): FSDataInputStream = {
+ fc.open(path)
+ }
+
+ override def exists(path: Path): Boolean = {
+ fc.util.exists(path)
+ }
+
+ override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = {
+ import Options.Rename._
+ fc.rename(srcPath, dstPath, if (overwriteIfPossible) OVERWRITE else NONE)
+ }
+
+
+ override def delete(path: Path): Unit = {
+ try {
+ fc.delete(path, true)
+ } catch {
+ case e: FileNotFoundException =>
+ // ignore if file has already been deleted
+ }
+ }
+
+ override def isLocal: Boolean = fc.getDefaultFileSystem match {
+ case _: LocalFs | _: RawLocalFs => true // LocalFs = RawLocalFs + ChecksumFs
+ case _ => false
+ }
+}
+
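The createAtomic contract described above (no partially written file ever becomes visible, and a failed write can be cancelled) can be sketched outside Hadoop with java.nio: write to a temp file in the same directory, atomically move it into place, and delete the temp file on failure. This illustrates the semantics only and is not the CheckpointFileManager API; ATOMIC_MOVE support depends on the underlying file system.

    import java.nio.charset.StandardCharsets
    import java.nio.file.{Files, Path, StandardCopyOption}

    object AtomicWriteSketch {
      /** Writes `contents` to `target` so that readers never observe a partial file. */
      def writeAtomic(target: Path, contents: String): Unit = {
        val temp = Files.createTempFile(target.getParent, s".${target.getFileName}", ".tmp")
        try {
          Files.write(temp, contents.getBytes(StandardCharsets.UTF_8))
          Files.move(temp, target, StandardCopyOption.ATOMIC_MOVE)
        } catch {
          case e: Throwable =>
            Files.deleteIfExists(temp) // "cancel": make sure no output file is left behind
            throw e
        }
      }

      def main(args: Array[String]): Unit = {
        val dir = Files.createTempDirectory("atomic-write-sketch")
        writeAtomic(dir.resolve("offsets"), "batch-42")
        println(new String(Files.readAllBytes(dir.resolve("offsets")), StandardCharsets.UTF_8))
      }
    }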
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index 2715fa93d0e98..b3d12f67b5d63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -26,7 +26,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter}
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter}
+import org.apache.spark.util.SerializableConfiguration
object FileStreamSink extends Logging {
// The name of the subdirectory that is used to store metadata about which files are valid.
@@ -42,9 +43,11 @@ object FileStreamSink extends Logging {
try {
val hdfsPath = new Path(singlePath)
val fs = hdfsPath.getFileSystem(hadoopConf)
- val metadataPath = new Path(hdfsPath, metadataDir)
- val res = fs.exists(metadataPath)
- res
+ if (fs.isDirectory(hdfsPath)) {
+ fs.exists(new Path(hdfsPath, metadataDir))
+ } else {
+ false
+ }
} catch {
case NonFatal(e) =>
logWarning(s"Error while looking for metadata directory.")
@@ -95,6 +98,11 @@ class FileStreamSink(
new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString)
private val hadoopConf = sparkSession.sessionState.newHadoopConf()
+ private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = {
+ val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+ new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics)
+ }
+
override def addBatch(batchId: Long, data: DataFrame): Unit = {
if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) {
logInfo(s"Skipping already committed batch $batchId")
@@ -129,7 +137,7 @@ class FileStreamSink(
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = None,
- statsTrackers = Nil,
+ statsTrackers = Seq(basicWriteJobStatsTracker),
options = options)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 80769d728b8f1..8e82cccbc8fa3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -97,6 +97,18 @@ case class FlatMapGroupsWithStateExec(
override def keyExpressions: Seq[Attribute] = groupingAttributes
+ override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
+ timeoutConf match {
+ case ProcessingTimeTimeout =>
+ true // Always run batches to process timeouts
+ case EventTimeTimeout =>
+ // Process another non-data batch only if the watermark has changed in this executed plan
+ eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
+ case _ =>
+ false
+ }
+ }
+
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
@@ -126,7 +138,6 @@ case class FlatMapGroupsWithStateExec(
case _ =>
iter
}
-
// Generate an iterator that returns the rows grouped by the grouping function
// Note that this code ensures that the filtering for timeout occurs only after
// all the data has been processed. This is to ensure that the timeout information of all
@@ -194,11 +205,11 @@ case class FlatMapGroupsWithStateExec(
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
}
- val timingOutKeys = store.getRange(None, None).filter { rowPair =>
+ val timingOutPairs = store.getRange(None, None).filter { rowPair =>
val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
}
- timingOutKeys.flatMap { rowPair =>
+ timingOutPairs.flatMap { rowPair =>
callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
}
} else Iterator.empty
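
Reviewer note: the new `shouldRunAnotherBatch` hook above reduces to a small decision on the timeout mode: processing-time timeouts always justify another (possibly data-less) batch, event-time timeouts only when the watermark has advanced past the value baked into the last executed plan. A self-contained sketch of that decision, with illustrative stand-in types rather than the actual Spark ones:

    sealed trait TimeoutConf
    case object ProcessingTimeTimeout extends TimeoutConf
    case object EventTimeTimeout extends TimeoutConf
    case object NoTimeout extends TimeoutConf

    def shouldRunAnotherBatch(
        timeoutConf: TimeoutConf,
        planWatermarkMs: Option[Long],   // watermark used by the last executed plan
        newBatchWatermarkMs: Long): Boolean = timeoutConf match {
      case ProcessingTimeTimeout => true // always run batches to fire processing-time timeouts
      case EventTimeTimeout => planWatermarkMs.exists(newBatchWatermarkMs > _)
      case NoTimeout => false
    }
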
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
deleted file mode 100644
index 2cc54107f8b83..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
-import org.apache.spark.sql.catalyst.encoders.encoderFor
-
-/**
- * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
- * [[ForeachWriter]].
- *
- * @param writer The [[ForeachWriter]] to process all data.
- * @tparam T The expected type of the sink.
- */
-class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
-
- override def addBatch(batchId: Long, data: DataFrame): Unit = {
- // This logic should've been as simple as:
- // ```
- // data.as[T].foreachPartition { iter => ... }
- // ```
- //
- // Unfortunately, doing that would just break the incremental planing. The reason is,
- // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will
- // create a new plan. Because StreamExecution uses the existing plan to collect metrics and
- // update watermark, we should never create a new plan. Otherwise, metrics and watermark are
- // updated in the new plan, and StreamExecution cannot retrieval them.
- //
- // Hence, we need to manually convert internal rows to objects using encoder.
- val encoder = encoderFor[T].resolveAndBind(
- data.logicalPlan.output,
- data.sparkSession.sessionState.analyzer)
- data.queryExecution.toRdd.foreachPartition { iter =>
- if (writer.open(TaskContext.getPartitionId(), batchId)) {
- try {
- while (iter.hasNext) {
- writer.process(encoder.fromRow(iter.next()))
- }
- } catch {
- case e: Throwable =>
- writer.close(e)
- throw e
- }
- writer.close(null)
- } else {
- writer.close(null)
- }
- }
- }
-
- override def toString(): String = "ForeachSink"
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
index 00bc215a5dc8c..bd0a46115ceb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
@@ -57,10 +57,10 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
require(implicitly[ClassTag[T]].runtimeClass != classOf[Seq[_]],
"Should not create a log with type Seq, use Arrays instead - see SPARK-17372")
- import HDFSMetadataLog._
-
val metadataPath = new Path(path)
- protected val fileManager = createFileManager()
+
+ protected val fileManager =
+ CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf)
if (!fileManager.exists(metadataPath)) {
fileManager.mkdirs(metadataPath)
@@ -109,84 +109,31 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
require(metadata != null, "'null' metadata cannot written to a metadata log")
get(batchId).map(_ => false).getOrElse {
// Only write metadata when the batch has not yet been written
- writeBatch(batchId, metadata)
+ writeBatchToFile(metadata, batchIdToPath(batchId))
true
}
}
- private def writeTempBatch(metadata: T): Option[Path] = {
- while (true) {
- val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp")
- try {
- val output = fileManager.create(tempPath)
- try {
- serialize(metadata, output)
- return Some(tempPath)
- } finally {
- output.close()
- }
- } catch {
- case e: FileAlreadyExistsException =>
- // Failed to create "tempPath". There are two cases:
- // 1. Someone is creating "tempPath" too.
- // 2. This is a restart. "tempPath" has already been created but not moved to the final
- // batch file (not committed).
- //
- // For both cases, the batch has not yet been committed. So we can retry it.
- //
- // Note: there is a potential risk here: if HDFSMetadataLog A is running, people can use
- // the same metadata path to create "HDFSMetadataLog" and fail A. However, this is not a
- // big problem because it requires the attacker must have the permission to write the
- // metadata path. In addition, the old Streaming also have this issue, people can create
- // malicious checkpoint files to crash a Streaming application too.
- }
- }
- None
- }
-
- /**
- * Write a batch to a temp file then rename it to the batch file.
+ /** Write a batch to a temp file then rename it to the batch file.
*
* There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a
* valid behavior, we still need to prevent it from destroying the files.
*/
- private def writeBatch(batchId: Long, metadata: T): Unit = {
- val tempPath = writeTempBatch(metadata).getOrElse(
- throw new IllegalStateException(s"Unable to create temp batch file $batchId"))
+ private def writeBatchToFile(metadata: T, path: Path): Unit = {
+ val output = fileManager.createAtomic(path, overwriteIfPossible = false)
try {
- // Try to commit the batch
- // It will fail if there is an existing file (someone has committed the batch)
- logDebug(s"Attempting to write log #${batchIdToPath(batchId)}")
- fileManager.rename(tempPath, batchIdToPath(batchId))
-
- // SPARK-17475: HDFSMetadataLog should not leak CRC files
- // If the underlying filesystem didn't rename the CRC file, delete it.
- val crcPath = new Path(tempPath.getParent(), s".${tempPath.getName()}.crc")
- if (fileManager.exists(crcPath)) fileManager.delete(crcPath)
+ serialize(metadata, output)
+ output.close()
} catch {
case e: FileAlreadyExistsException =>
- // If "rename" fails, it means some other "HDFSMetadataLog" has committed the batch.
- // So throw an exception to tell the user this is not a valid behavior.
+ output.cancel()
+ // If next batch file already exists, then another concurrently running query has
+ // written it.
throw new ConcurrentModificationException(
- s"Multiple HDFSMetadataLog are using $path", e)
- } finally {
- fileManager.delete(tempPath)
- }
- }
-
- /**
- * @return the deserialized metadata in a batch file, or None if file not exist.
- * @throws IllegalArgumentException when path does not point to a batch file.
- */
- def get(batchFile: Path): Option[T] = {
- if (fileManager.exists(batchFile)) {
- if (isBatchFile(batchFile)) {
- get(pathToBatchId(batchFile))
- } else {
- throw new IllegalArgumentException(s"File ${batchFile} is not a batch file!")
- }
- } else {
- None
+ s"Multiple streaming queries are concurrently using $path", e)
+ case e: Throwable =>
+ output.cancel()
+ throw e
}
}
@@ -219,7 +166,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
(endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get)
}.sorted
- verifyBatchIds(batchIds, startId, endId)
+ HDFSMetadataLog.verifyBatchIds(batchIds, startId, endId)
batchIds.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map {
case (batchId, metadataOption) =>
@@ -280,19 +227,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
}
}
- private def createFileManager(): FileManager = {
- val hadoopConf = sparkSession.sessionState.newHadoopConf()
- try {
- new FileContextManager(metadataPath, hadoopConf)
- } catch {
- case e: UnsupportedFileSystemException =>
- logWarning("Could not use FileContext API for managing metadata log files at path " +
- s"$metadataPath. Using FileSystem API instead for managing log files. The log may be " +
- s"inconsistent under failures.")
- new FileSystemManager(metadataPath, hadoopConf)
- }
- }
-
/**
* Parse the log version from the given `text` -- will throw exception when the parsed version
* exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1",
@@ -327,135 +261,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
object HDFSMetadataLog {
- /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */
- trait FileManager {
-
- /** List the files in a path that match a filter. */
- def list(path: Path, filter: PathFilter): Array[FileStatus]
-
- /** Make directory at the give path and all its parent directories as needed. */
- def mkdirs(path: Path): Unit
-
- /** Whether path exists */
- def exists(path: Path): Boolean
-
- /** Open a file for reading, or throw exception if it does not exist. */
- def open(path: Path): FSDataInputStream
-
- /** Create path, or throw exception if it already exists */
- def create(path: Path): FSDataOutputStream
-
- /**
- * Atomically rename path, or throw exception if it cannot be done.
- * Should throw FileNotFoundException if srcPath does not exist.
- * Should throw FileAlreadyExistsException if destPath already exists.
- */
- def rename(srcPath: Path, destPath: Path): Unit
-
- /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */
- def delete(path: Path): Unit
- }
-
- /**
- * Default implementation of FileManager using newer FileContext API.
- */
- class FileContextManager(path: Path, hadoopConf: Configuration) extends FileManager {
- private val fc = if (path.toUri.getScheme == null) {
- FileContext.getFileContext(hadoopConf)
- } else {
- FileContext.getFileContext(path.toUri, hadoopConf)
- }
-
- override def list(path: Path, filter: PathFilter): Array[FileStatus] = {
- fc.util.listStatus(path, filter)
- }
-
- override def rename(srcPath: Path, destPath: Path): Unit = {
- fc.rename(srcPath, destPath)
- }
-
- override def mkdirs(path: Path): Unit = {
- fc.mkdir(path, FsPermission.getDirDefault, true)
- }
-
- override def open(path: Path): FSDataInputStream = {
- fc.open(path)
- }
-
- override def create(path: Path): FSDataOutputStream = {
- fc.create(path, EnumSet.of(CreateFlag.CREATE))
- }
-
- override def exists(path: Path): Boolean = {
- fc.util().exists(path)
- }
-
- override def delete(path: Path): Unit = {
- try {
- fc.delete(path, true)
- } catch {
- case e: FileNotFoundException =>
- // ignore if file has already been deleted
- }
- }
- }
-
- /**
- * Implementation of FileManager using older FileSystem API. Note that this implementation
- * cannot provide atomic renaming of paths, hence can lead to consistency issues. This
- * should be used only as a backup option, when FileContextManager cannot be used.
- */
- class FileSystemManager(path: Path, hadoopConf: Configuration) extends FileManager {
- private val fs = path.getFileSystem(hadoopConf)
-
- override def list(path: Path, filter: PathFilter): Array[FileStatus] = {
- fs.listStatus(path, filter)
- }
-
- /**
- * Rename a path. Note that this implementation is not atomic.
- * @throws FileNotFoundException if source path does not exist.
- * @throws FileAlreadyExistsException if destination path already exists.
- * @throws IOException if renaming fails for some unknown reason.
- */
- override def rename(srcPath: Path, destPath: Path): Unit = {
- if (!fs.exists(srcPath)) {
- throw new FileNotFoundException(s"Source path does not exist: $srcPath")
- }
- if (fs.exists(destPath)) {
- throw new FileAlreadyExistsException(s"Destination path already exists: $destPath")
- }
- if (!fs.rename(srcPath, destPath)) {
- throw new IOException(s"Failed to rename $srcPath to $destPath")
- }
- }
-
- override def mkdirs(path: Path): Unit = {
- fs.mkdirs(path, FsPermission.getDirDefault)
- }
-
- override def open(path: Path): FSDataInputStream = {
- fs.open(path)
- }
-
- override def create(path: Path): FSDataOutputStream = {
- fs.create(path, false)
- }
-
- override def exists(path: Path): Boolean = {
- fs.exists(path)
- }
-
- override def delete(path: Path): Unit = {
- try {
- fs.delete(path, true)
- } catch {
- case e: FileNotFoundException =>
- // ignore if file has already been deleted
- }
- }
- }
-
/**
* Verify if batchIds are continuous and between `startId` and `endId`.
*
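
Reviewer note: the rewritten write path above replaces the temp-file-plus-rename dance with a single stream obtained from the checkpoint file manager: closing the stream publishes the batch file, cancelling it discards any partial output. A sketch of that commit-or-cancel contract under assumed names (`CancellableStream` is a stand-in, not the real `CheckpointFileManager` API):

    import java.io.OutputStream

    trait CancellableStream {
      def underlying: OutputStream
      def close(): Unit   // commit: atomically make the file visible
      def cancel(): Unit  // abort: leave no partial file behind
    }

    // Serialize, then either commit or cancel -- never leave a half-written batch file.
    def writeAtomically(out: CancellableStream)(serialize: OutputStream => Unit): Unit = {
      try {
        serialize(out.underlying)
        out.close()
      } catch {
        case e: Throwable =>
          out.cancel()
          throw e
      }
    }
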
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index a10ed5f2df1b5..c480b96626f84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -62,7 +62,7 @@ class IncrementalExecution(
StreamingDeduplicationStrategy :: Nil
}
- private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
+ private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
.map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
.getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
@@ -143,4 +143,14 @@ class IncrementalExecution(
/** No need assert supported, as this check has already been done */
override def assertSupported(): Unit = { }
+
+ /**
+ * Should the MicroBatchExecution run another batch based on this execution and the current
+ * updated metadata.
+ */
+ def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
+ executedPlan.collect {
+ case p: StateStoreWriter => p.shouldRunAnotherBatch(newMetadata)
+ }.exists(_ == true)
+ }
}
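
Reviewer note: the `shouldRunAnotherBatch` added above simply asks every stateful writer in the executed plan and answers true if any of them needs another batch. A condensed sketch of that aggregation (the trait and its parameter are illustrative stand-ins for `StateStoreWriter` and `OffsetSeqMetadata`):

    trait StatefulWriter {
      def shouldRunAnotherBatch(newWatermarkMs: Long): Boolean
    }

    // The query needs another (possibly data-less) batch if any stateful operator says so.
    def planNeedsAnotherBatch(statefulOps: Seq[StatefulWriter], newWatermarkMs: Long): Boolean =
      statefulOps.exists(_.shouldRunAnotherBatch(newWatermarkMs))
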
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
index 5f0b195fcfcb8..3ff5b86ac45d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.execution.streaming
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
+
/**
* A simple offset for sources that produce a single linear stream of data.
*/
-case class LongOffset(offset: Long) extends Offset {
+case class LongOffset(offset: Long) extends OffsetV2 {
override val json = offset.toString
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
index 1da703cefd8ea..5cacdd070b735 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
@@ -30,14 +30,14 @@ import org.apache.spark.sql.types.StructType
* A [[FileIndex]] that generates the list of files to process by reading them from the
* metadata log files generated by the [[FileStreamSink]].
*
- * @param userPartitionSchema an optional partition schema that will be use to provide types for
- * the discovered partitions
+ * @param userSpecifiedSchema an optional user-specified schema that will be used to provide
+ * types for the discovered partitions
*/
class MetadataLogFileIndex(
sparkSession: SparkSession,
path: Path,
- userPartitionSchema: Option[StructType])
- extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) {
+ userSpecifiedSchema: Option[StructType])
+ extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) {
private val metadataDirectory = new Path(path, FileStreamSink.metadataDir)
logInfo(s"Reading streaming file log from $metadataDirectory")
@@ -51,7 +51,7 @@ class MetadataLogFileIndex(
}
override protected val leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = {
- allFilesFromLog.toArray.groupBy(_.getPath.getParent)
+ allFilesFromLog.groupBy(_.getPath.getParent)
}
override def rootPaths: Seq[Path] = path :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index d9aa8573ba930..17ffa2a517312 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -20,19 +20,18 @@ package org.apache.spark.sql.execution.streaming
import java.util.Optional
import scala.collection.JavaConverters._
-import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
+import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter}
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.MicroBatchReadSupport
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.sources.v2.writer.{StreamWriteSupport, SupportsWriteInternalRow}
+import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
import org.apache.spark.util.{Clock, Utils}
@@ -53,12 +52,17 @@ class MicroBatchExecution(
@volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty
+ private val readerToDataSourceMap =
+ MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])]
+
private val triggerExecutor = trigger match {
case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock)
case OneTimeTrigger => OneTimeExecutor()
case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger")
}
+ private val watermarkTracker = new WatermarkTracker()
+
override lazy val logicalPlan: LogicalPlan = {
assert(queryExecutionThread eq Thread.currentThread,
"logicalPlan must be initialized in QueryExecutionThread " +
@@ -73,27 +77,37 @@ class MicroBatchExecution(
// Note that we have to use the previous `output` as attributes in StreamingExecutionRelation,
// since the existing logical plan has already used those attributes. The per-microbatch
// transformation is responsible for replacing attributes with their final values.
+
+ val disabledSources =
+ sparkSession.sqlContext.conf.disabledV2StreamingMicroBatchReaders.split(",")
+
val _logicalPlan = analyzedPlan.transform {
- case streamingRelation@StreamingRelation(dataSource, _, output) =>
+ case streamingRelation@StreamingRelation(dataSourceV1, sourceName, output) =>
toExecutionRelationMap.getOrElseUpdate(streamingRelation, {
// Materialize source to avoid creating it in every batch
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
- val source = dataSource.createSource(metadataPath)
+ val source = dataSourceV1.createSource(metadataPath)
nextSourceId += 1
+ logInfo(s"Using Source [$source] from DataSourceV1 named '$sourceName' [$dataSourceV1]")
StreamingExecutionRelation(source, output)(sparkSession)
})
- case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) =>
+ case s @ StreamingRelationV2(
+ dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if
+ !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) =>
v2ToExecutionRelationMap.getOrElseUpdate(s, {
// Materialize source to avoid creating it in every batch
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
- val reader = source.createMicroBatchReader(
+ val reader = dataSourceV2.createMicroBatchReader(
Optional.empty(), // user specified schema
metadataPath,
new DataSourceOptions(options.asJava))
nextSourceId += 1
+ readerToDataSourceMap(reader) = dataSourceV2 -> options
+ logInfo(s"Using MicroBatchReader [$reader] from " +
+ s"DataSourceV2 named '$sourceName' [$dataSourceV2]")
StreamingExecutionRelation(reader, output)(sparkSession)
})
- case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) =>
+ case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) =>
v2ToExecutionRelationMap.getOrElseUpdate(s, {
// Materialize source to avoid creating it in every batch
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
@@ -103,6 +117,7 @@ class MicroBatchExecution(
}
val source = v1Relation.get.dataSource.createSource(metadataPath)
nextSourceId += 1
+ logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$dataSourceV2]")
StreamingExecutionRelation(source, output)(sparkSession)
})
}
@@ -111,44 +126,87 @@ class MicroBatchExecution(
_logicalPlan
}
+ /**
+ * Signifies whether the current batch (i.e. for the batch `currentBatchId`) has been constructed
+ * (i.e. written to the offsetLog) and is ready for execution.
+ */
+ private var isCurrentBatchConstructed = false
+
+ /**
+ * Signals to the thread executing micro-batches that it should stop running after the next
+ * batch. This method blocks until the thread stops running.
+ */
+ override def stop(): Unit = {
+ // Set the state to TERMINATED so that the batching thread knows that it was interrupted
+ // intentionally
+ state.set(TERMINATED)
+ if (queryExecutionThread.isAlive) {
+ sparkSession.sparkContext.cancelJobGroup(runId.toString)
+ queryExecutionThread.interrupt()
+ queryExecutionThread.join()
+ // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak
+ sparkSession.sparkContext.cancelJobGroup(runId.toString)
+ }
+ logInfo(s"Query $prettyIdString was stopped")
+ }
+
/**
* Repeatedly attempts to run batches as data arrives.
*/
protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
- triggerExecutor.execute(() => {
- startTrigger()
+ val noDataBatchesEnabled =
+ sparkSessionForStream.sessionState.conf.streamingNoDataMicroBatchesEnabled
+
+ triggerExecutor.execute(() => {
if (isActive) {
+ var currentBatchHasNewData = false // Whether the current batch had new data
+
+ startTrigger()
+
reportTimeTaken("triggerExecution") {
+ // We'll do this initialization only once per start / restart
if (currentBatchId < 0) {
- // We'll do this initialization only once
populateStartOffsets(sparkSessionForStream)
- sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
- logDebug(s"Stream running from $committedOffsets to $availableOffsets")
- } else {
- constructNextBatch()
+ logInfo(s"Stream started from $committedOffsets")
+ }
+
+ // Set this before calling constructNextBatch() so any Spark jobs executed by sources
+ // while getting new data have the correct description
+ sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
+
+ // Try to construct the next batch. This will return true only if the next batch is
+ // ready and runnable. Note that the current batch may be runnable even without
+ // new data to process as `constructNextBatch` may decide to run a batch for
+ // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data
+ // is available or not.
+ if (!isCurrentBatchConstructed) {
+ isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled)
}
- if (dataAvailable) {
- currentStatus = currentStatus.copy(isDataAvailable = true)
- updateStatusMessage("Processing new data")
+
+ // Remember whether the current batch has data or not. This will be required later
+ // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed
+ // to false as the batch would have already processed the available data.
+ currentBatchHasNewData = isNewDataAvailable
+
+ currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable)
+ if (isCurrentBatchConstructed) {
+ if (currentBatchHasNewData) updateStatusMessage("Processing new data")
+ else updateStatusMessage("No new data but cleaning up state")
runBatch(sparkSessionForStream)
+ } else {
+ updateStatusMessage("Waiting for data to arrive")
}
}
- // Report trigger as finished and construct progress object.
- finishTrigger(dataAvailable)
- if (dataAvailable) {
- // Update committed offsets.
- commitLog.add(currentBatchId)
- committedOffsets ++= availableOffsets
- logDebug(s"batch ${currentBatchId} committed")
- // We'll increase currentBatchId after we complete processing current batch's data
+
+ finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded
+
+ // If the current batch has been executed, then increment the batch id and reset the flag.
+ // Otherwise, no batch was constructed, so sleep for some time
+ if (isCurrentBatchConstructed) {
currentBatchId += 1
- sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
- } else {
- currentStatus = currentStatus.copy(isDataAvailable = false)
- updateStatusMessage("Waiting for data to arrive")
- Thread.sleep(pollingDelayMs)
- }
+ isCurrentBatchConstructed = false
+ } else Thread.sleep(pollingDelayMs)
}
updateStatusMessage("Waiting for next trigger")
isActive
@@ -183,6 +241,7 @@ class MicroBatchExecution(
/* First assume that we are re-executing the latest known batch
* in the offset log */
currentBatchId = latestBatchId
+ isCurrentBatchConstructed = true
availableOffsets = nextOffsets.toStreamProgress(sources)
/* Initialize committed offsets to a committed batch, which at this
* is the second latest batch id in the offset log. */
@@ -198,6 +257,7 @@ class MicroBatchExecution(
OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf)
offsetSeqMetadata = OffsetSeqMetadata(
metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf)
+ watermarkTracker.setWatermark(metadata.batchWatermarkMs)
}
/* identify the current batch id: if commit log indicates we successfully processed the
@@ -220,9 +280,9 @@ class MicroBatchExecution(
// here, so we do nothing here.
}
currentBatchId = latestCommittedBatchId + 1
+ isCurrentBatchConstructed = false
committedOffsets ++= availableOffsets
// Construct a new batch by recomputing availableOffsets
- constructNextBatch()
} else if (latestCommittedBatchId < latestBatchId - 1) {
logWarning(s"Batch completion log latest batch id is " +
s"${latestCommittedBatchId}, which is not trailing " +
@@ -230,19 +290,18 @@ class MicroBatchExecution(
}
case None => logInfo("no commit log present")
}
- logDebug(s"Resuming at batch $currentBatchId with committed offsets " +
+ logInfo(s"Resuming at batch $currentBatchId with committed offsets " +
s"$committedOffsets and available offsets $availableOffsets")
case None => // We are starting this stream for the first time.
logInfo(s"Starting new streaming query.")
currentBatchId = 0
- constructNextBatch()
}
}
/**
* Returns true if there is any new data available to be processed.
*/
- private def dataAvailable: Boolean = {
+ private def isNewDataAvailable: Boolean = {
availableOffsets.exists {
case (source, available) =>
committedOffsets
@@ -253,92 +312,65 @@ class MicroBatchExecution(
}
/**
- * Queries all of the sources to see if any new data is available. When there is new data the
- * batchId counter is incremented and a new log entry is written with the newest offsets.
+ * Attempts to construct a batch according to:
+ * - Availability of new data
+ * - Need for timeouts and state cleanups in stateful operators
+ *
+ * Returns true only if the next batch should be executed.
+ *
+ * Here is the high-level logic on how this constructs the next batch.
+ * - Check each source whether new data is available
+ * - Update the query's metadata and check, using the last execution, whether there is any need
+ * to run another batch (for state clean up, etc.)
+ * - If either of the above is true, then construct the next batch by committing to the offset
+ * log the range of offsets that the next batch will process.
*/
- private def constructNextBatch(): Unit = {
- // Check to see what new data is available.
- val hasNewData = {
- awaitProgressLock.lock()
- try {
- // Generate a map from each unique source to the next available offset.
- val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map {
- case s: Source =>
- updateStatusMessage(s"Getting offsets from $s")
- reportTimeTaken("getOffset") {
- (s, s.getOffset)
- }
- case s: MicroBatchReader =>
- updateStatusMessage(s"Getting offsets from $s")
- reportTimeTaken("getOffset") {
- // Once v1 streaming source execution is gone, we can refactor this away.
- // For now, we set the range here to get the source to infer the available end offset,
- // get that offset, and then set the range again when we later execute.
- s.setOffsetRange(
- toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
- Optional.empty())
-
- (s, Some(s.getEndOffset))
- }
- }.toMap
- availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)
-
- if (dataAvailable) {
- true
- } else {
- noNewData = true
- false
- }
- } finally {
- awaitProgressLock.unlock()
- }
- }
- if (hasNewData) {
- var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs
- // Update the eventTime watermarks if we find any in the plan.
- if (lastExecution != null) {
- lastExecution.executedPlan.collect {
- case e: EventTimeWatermarkExec => e
- }.zipWithIndex.foreach {
- case (e, index) if e.eventTimeStats.value.count > 0 =>
- logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}")
- val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs
- val prevWatermarkMs = watermarkMsMap.get(index)
- if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) {
- watermarkMsMap.put(index, newWatermarkMs)
- }
-
- // Populate 0 if we haven't seen any data yet for this watermark node.
- case (_, index) =>
- if (!watermarkMsMap.isDefinedAt(index)) {
- watermarkMsMap.put(index, 0)
- }
+ private def constructNextBatch(noDataBatchesEnabled: Boolean): Boolean = withProgressLocked {
+ if (isCurrentBatchConstructed) return true
+
+ // Generate a map from each unique source to the next available offset.
+ val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map {
+ case s: Source =>
+ updateStatusMessage(s"Getting offsets from $s")
+ reportTimeTaken("getOffset") {
+ (s, s.getOffset)
}
-
- // Update the global watermark to the minimum of all watermark nodes.
- // This is the safest option, because only the global watermark is fault-tolerant. Making
- // it the minimum of all individual watermarks guarantees it will never advance past where
- // any individual watermark operator would be if it were in a plan by itself.
- if(!watermarkMsMap.isEmpty) {
- val newWatermarkMs = watermarkMsMap.minBy(_._2)._2
- if (newWatermarkMs > batchWatermarkMs) {
- logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms")
- batchWatermarkMs = newWatermarkMs
- } else {
- logDebug(
- s"Event time didn't move: $newWatermarkMs < " +
- s"$batchWatermarkMs")
- }
+ case s: MicroBatchReader =>
+ updateStatusMessage(s"Getting offsets from $s")
+ reportTimeTaken("setOffsetRange") {
+ // Once v1 streaming source execution is gone, we can refactor this away.
+ // For now, we set the range here to get the source to infer the available end offset,
+ // get that offset, and then set the range again when we later execute.
+ s.setOffsetRange(
+ toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
+ Optional.empty())
}
- }
- offsetSeqMetadata = offsetSeqMetadata.copy(
- batchWatermarkMs = batchWatermarkMs,
- batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds
+ val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() }
+ (s, Option(currentOffset))
+ }.toMap
+ availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)
+
+ // Update the query metadata
+ offsetSeqMetadata = offsetSeqMetadata.copy(
+ batchWatermarkMs = watermarkTracker.currentWatermark,
+ batchTimestampMs = triggerClock.getTimeMillis())
+
+ // Check whether next batch should be constructed
+ val lastExecutionRequiresAnotherBatch = noDataBatchesEnabled &&
+ Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata))
+ val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch
+ logTrace(
+ s"noDataBatchesEnabled = $noDataBatchesEnabled, " +
+ s"lastExecutionRequiresAnotherBatch = $lastExecutionRequiresAnotherBatch, " +
+ s"isNewDataAvailable = $isNewDataAvailable, " +
+ s"shouldConstructNextBatch = $shouldConstructNextBatch")
+
+ if (shouldConstructNextBatch) {
+ // Commit the next batch offset range to the offset log
updateStatusMessage("Writing offsets to log")
reportTimeTaken("walCommit") {
- assert(offsetLog.add(
- currentBatchId,
+ assert(offsetLog.add(currentBatchId,
availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)),
s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId")
logInfo(s"Committed offsets for batch $currentBatchId. " +
@@ -359,7 +391,7 @@ class MicroBatchExecution(
reader.commit(reader.deserializeOffset(off.json))
}
} else {
- throw new IllegalStateException(s"batch $currentBatchId doesn't exist")
+ throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist")
}
}
@@ -370,15 +402,12 @@ class MicroBatchExecution(
commitLog.purge(currentBatchId - minLogEntriesToMaintain)
}
}
+ noNewData = false
} else {
- awaitProgressLock.lock()
- try {
- // Wake up any threads that are waiting for the stream to progress.
- awaitProgressLockCondition.signalAll()
- } finally {
- awaitProgressLock.unlock()
- }
+ noNewData = true
+ awaitProgressLockCondition.signalAll()
}
+ shouldConstructNextBatch
}
/**
@@ -386,6 +415,8 @@ class MicroBatchExecution(
* @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with.
*/
private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = {
+ logDebug(s"Running batch $currentBatchId")
+
// Request unprocessed data from all sources.
newData = reportTimeTaken("getBatch") {
availableOffsets.flatMap {
@@ -401,18 +432,31 @@ class MicroBatchExecution(
case (reader: MicroBatchReader, available)
if committedOffsets.get(reader).map(_ != available).getOrElse(true) =>
val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json))
+ val availableV2: OffsetV2 = available match {
+ case v1: SerializedOffset => reader.deserializeOffset(v1.json)
+ case v2: OffsetV2 => v2
+ }
reader.setOffsetRange(
toJava(current),
- Optional.of(available.asInstanceOf[OffsetV2]))
- logDebug(s"Retrieving data from $reader: $current -> $available")
- Some(reader ->
- new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader))
+ Optional.of(availableV2))
+ logDebug(s"Retrieving data from $reader: $current -> $availableV2")
+
+ val (source, options) = reader match {
+ // `MemoryStream` is special. It's for testing only and doesn't have a `DataSourceV2`
+ // implementation. We provide a fake one here for the explain output.
+ case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String]
+ // Provide a fake value here just in case something goes wrong, e.g. the reader has
+ // a broken `equals` implementation.
+ case _ => readerToDataSourceMap.getOrElse(reader, {
+ FakeDataSourceV2 -> Map.empty[String, String]
+ })
+ }
+ Some(reader -> StreamingDataSourceV2Relation(
+ reader.readSchema().toAttributes, source, options, reader))
case _ => None
}
}
- // A list of attributes that will need to be updated.
- val replacements = new ArrayBuffer[(Attribute, Attribute)]
// Replace sources in the logical plan with data that has arrived since the last batch.
val newBatchesPlan = logicalPlan transform {
case StreamingExecutionRelation(source, output) =>
@@ -420,18 +464,18 @@ class MicroBatchExecution(
assert(output.size == dataPlan.output.size,
s"Invalid batch: ${Utils.truncatedString(output, ",")} != " +
s"${Utils.truncatedString(dataPlan.output, ",")}")
- replacements ++= output.zip(dataPlan.output)
- dataPlan
+
+ val aliases = output.zip(dataPlan.output).map { case (to, from) =>
+ Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata))
+ }
+ Project(aliases, dataPlan)
}.getOrElse {
LocalRelation(output, isStreaming = true)
}
}
// Rewire the plan to use the new attributes that were returned by the source.
- val replacementMap = AttributeMap(replacements)
val newAttributePlan = newBatchesPlan transformAllExpressions {
- case a: Attribute if replacementMap.contains(a) =>
- replacementMap(a).withMetadata(a.metadata)
case ct: CurrentTimestamp =>
CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
ct.dataType)
@@ -457,6 +501,9 @@ class MicroBatchExecution(
case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
}
+ sparkSessionToRunBatch.sparkContext.setLocalProperty(
+ MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString)
+
reportTimeTaken("queryPlanning") {
lastExecution = new IncrementalExecution(
sparkSessionToRunBatch,
@@ -483,10 +530,20 @@ class MicroBatchExecution(
}
}
+ withProgressLocked {
+ commitLog.add(currentBatchId)
+ committedOffsets ++= availableOffsets
+ awaitProgressLockCondition.signalAll()
+ }
+ watermarkTracker.updateWatermark(lastExecution.executedPlan)
+ logDebug(s"Completed batch ${currentBatchId}")
+ }
+
+ /** Execute a function while preventing the stream from making any progress */
+ private[sql] def withProgressLocked[T](f: => T): T = {
awaitProgressLock.lock()
try {
- // Wake up any threads that are waiting for the stream to progress.
- awaitProgressLockCondition.signalAll()
+ f
} finally {
awaitProgressLock.unlock()
}
@@ -496,3 +553,11 @@ class MicroBatchExecution(
Optional.ofNullable(scalaOption.orNull)
}
}
+
+object MicroBatchExecution {
+ val BATCH_ID_KEY = "streaming.sql.batchId"
+}
+
+object MemoryStreamDataSource extends DataSourceV2
+
+object FakeDataSourceV2 extends DataSourceV2
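
Reviewer note: the main behavioural change above is in `constructNextBatch`: a batch is now constructed either when new data is available, or when no-data batches are enabled and the previous execution still needs one (e.g. to fire timeouts or clean up state). A condensed sketch of that decision, using plain booleans instead of the real query state:

    def shouldConstructNextBatch(
        isNewDataAvailable: Boolean,
        noDataBatchesEnabled: Boolean,
        lastExecutionNeedsAnotherBatch: Boolean): Boolean = {
      // New data always triggers a batch; a data-less batch only runs when enabled
      // and the previous execution asked for one.
      isNewDataAvailable || (noDataBatchesEnabled && lastExecutionNeedsAnotherBatch)
    }
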
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java
index 80aa5505db991..43ad4b3384ec3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java
@@ -19,8 +19,8 @@
/**
* This is an internal, deprecated interface. New source implementations should use the
- * org.apache.spark.sql.sources.v2.reader.Offset class, which is the one that will be supported
- * in the long term.
+ * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be
+ * supported in the long term.
*
* This class will be removed in a future release.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
index 73945b39b8967..787174481ff08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
@@ -39,7 +39,9 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet
* cannot be serialized).
*/
def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = {
- assert(sources.size == offsets.size)
+ assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " +
+ s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " +
+ s"Cannot continue.")
new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index d1e5be9c12762..16ad3ef9a3d4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -28,6 +28,8 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
+import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
import org.apache.spark.util.Clock
@@ -141,7 +143,7 @@ trait ProgressReporter extends Logging {
}
logDebug(s"Execution stats: $executionStats")
- val sourceProgress = sources.map { source =>
+ val sourceProgress = sources.distinct.map { source =>
val numRecords = executionStats.inputRows.getOrElse(source, 0L)
new SourceProgress(
description = source.toString,
@@ -207,62 +209,126 @@ trait ProgressReporter extends Logging {
return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp)
}
- // We want to associate execution plan leaves to sources that generate them, so that we match
- // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following.
- // Consider the translation from the streaming logical plan to the final executed plan.
- //
- // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan
- //
- // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan
- // - Each logical plan leaf will be associated with a single streaming source.
- // - There can be multiple logical plan leaves associated with a streaming source.
- // - There can be leaves not associated with any streaming source, because they were
- // generated from a batch source (e.g. stream-batch joins)
- //
- // 2. Assuming that the executed plan has same number of leaves in the same order as that of
- // the trigger logical plan, we associate executed plan leaves with corresponding
- // streaming sources.
- //
- // 3. For each source, we sum the metrics of the associated execution plan leaves.
- //
- val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) =>
- logicalPlan.collectLeaves().map { leaf => leaf -> source }
+ val numInputRows = extractSourceToNumInputRows()
+
+ val eventTimeStats = lastExecution.executedPlan.collect {
+ case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 =>
+ val stats = e.eventTimeStats.value
+ Map(
+ "max" -> stats.max,
+ "min" -> stats.min,
+ "avg" -> stats.avg.toLong).mapValues(formatTimestamp)
+ }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp
+
+ ExecutionStats(numInputRows, stateOperators, eventTimeStats)
+ }
+
+ /** Extract the number of input rows for each streaming source in the plan */
+ private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = {
+
+ import java.util.IdentityHashMap
+ import scala.collection.JavaConverters._
+
+ def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = {
+ tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source
}
- val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming
- val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves()
- val numInputRows: Map[BaseStreamingSource, Long] =
+
+ val onlyDataSourceV2Sources = {
+ // Check whether the streaming query's logical plan has only V2 data sources
+ val allStreamingLeaves =
+ logicalPlan.collect { case s: StreamingExecutionRelation => s }
+ allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] }
+ }
+
+ if (onlyDataSourceV2Sources) {
+ // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data
+ // from a V2 source and has a direct reference to the V2 source that generated it. Each
+ // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However,
+ // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as
+ // a DataSourceV2ScanExec instance may be referred to from two (or even more) points in the
+ // execution plan, and considering it twice will lead to double counting. We
+ // can't dedup them using their hashcode either because two different instances of
+ // DataSourceV2ScanExec can have the same hashcode but account for separate sets of
+ // records read, and deduping them to consider only one of them would be undercounting the
+ // records read. Therefore the right way to do this is to consider the unique instances of
+ // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them.
+ // Hence we calculate in the following way.
+ //
+ // 1. Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap.
+ //
+ // 2. Extract the source and the number of rows read from the DataSourceV2ScanExec instances.
+ //
+ // 3. Multiple DataSourceV2ScanExec instances may refer to the same source (can happen with
+ // self-unions or self-joins). Add up the number of rows for each unique source.
+ val uniqueStreamingExecLeavesMap =
+ new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]()
+
+ lastExecution.executedPlan.collectLeaves().foreach {
+ case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] =>
+ uniqueStreamingExecLeavesMap.put(s, s)
+ case _ =>
+ }
+
+ val sourceToInputRowsTuples =
+ uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf =>
+ val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
+ val source = execLeaf.reader.asInstanceOf[BaseStreamingSource]
+ source -> numRows
+ }.toSeq
+ logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t"))
+ sumRows(sourceToInputRowsTuples)
+ } else {
+
+ // Since V1 sources do not generate execution plan leaves that directly link to the source
+ // that generated them, we can only do a best-effort association between execution plan
+ // leaves and the sources. This is known to fail in a few cases, see SPARK-24050.
+ //
+ // We want to associate execution plan leaves with the sources that generate them, so that we
+ // match their metrics (e.g. numOutputRows) to the sources. To do this we do the following.
+ // Consider the translation from the streaming logical plan to the final executed plan.
+ //
+ // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan
+ //
+ // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan
+ // - Each logical plan leaf will be associated with a single streaming source.
+ // - There can be multiple logical plan leaves associated with a streaming source.
+ // - There can be leaves not associated with any streaming source, because they were
+ // generated from a batch source (e.g. stream-batch joins)
+ //
+ // 2. Assuming that the executed plan has the same number of leaves in the same order as that of
+ // the trigger logical plan, we associate executed plan leaves with corresponding
+ // streaming sources.
+ //
+ // 3. For each source, we sum the metrics of the associated execution plan leaves.
+ //
+ val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) =>
+ logicalPlan.collectLeaves().map { leaf => leaf -> source }
+ }
+ val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming
+ val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves()
if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) {
val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap {
case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source }
}
- val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) =>
+ val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) =>
val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
source -> numRows
}
- sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source
+ sumRows(sourceToInputRowsTuples)
} else {
if (!metricWarningLogged) {
def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}"
+
logWarning(
"Could not report metrics as number leaves in trigger logical plan did not match that" +
- s" of the execution plan:\n" +
- s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" +
- s"execution plan leaves: ${toString(allExecPlanLeaves)}\n")
+ s" of the execution plan:\n" +
+ s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" +
+ s"execution plan leaves: ${toString(allExecPlanLeaves)}\n")
metricWarningLogged = true
}
Map.empty
}
-
- val eventTimeStats = lastExecution.executedPlan.collect {
- case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 =>
- val stats = e.eventTimeStats.value
- Map(
- "max" -> stats.max,
- "min" -> stats.min,
- "avg" -> stats.avg.toLong).mapValues(formatTimestamp)
- }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp
-
- ExecutionStats(numInputRows, stateOperators, eventTimeStats)
+ }
}
/** Records the duration of running `body` for the next query progress update. */
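
Reviewer note: the V2 metrics path above dedupes scan nodes by object identity: the same `DataSourceV2ScanExec` instance referenced twice in the plan is counted once, while distinct-but-equal instances stay separate, and the per-instance counts are then summed per source. A small self-contained sketch of that pattern (`Leaf` stands in for the scan node):

    import java.util.IdentityHashMap
    import scala.collection.JavaConverters._

    case class Leaf(source: String, numOutputRows: Long)

    def sumRowsPerSource(planLeaves: Seq[Leaf]): Map[String, Long] = {
      // IdentityHashMap keys on reference equality, so a leaf that appears twice in
      // the plan is kept once, while equal-but-distinct leaves are both retained.
      val unique = new IdentityHashMap[Leaf, Leaf]()
      planLeaves.foreach(leaf => unique.put(leaf, leaf))
      unique.values.asScala.toSeq
        .groupBy(_.source)
        .mapValues(_.map(_.numOutputRows).sum)
        .toMap
    }
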
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
deleted file mode 100644
index ce5e63f5bde85..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
+++ /dev/null
@@ -1,263 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.io._
-import java.nio.charset.StandardCharsets
-import java.util.Optional
-import java.util.concurrent.TimeUnit
-
-import org.apache.commons.io.IOUtils
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
-import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
-import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
-import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport
-import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
-import org.apache.spark.sql.types._
-import org.apache.spark.util.{ManualClock, SystemClock}
-
-/**
- * A source that generates increment long values with timestamps. Each generated row has two
- * columns: a timestamp column for the generated time and an auto increment long column starting
- * with 0L.
- *
- * This source supports the following options:
- * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second.
- * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
- * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer
- * seconds.
- * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
- * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
- * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
- */
-class RateSourceProvider extends StreamSourceProvider with DataSourceRegister
- with DataSourceV2 with ContinuousReadSupport {
-
- override def sourceSchema(
- sqlContext: SQLContext,
- schema: Option[StructType],
- providerName: String,
- parameters: Map[String, String]): (String, StructType) = {
- if (schema.nonEmpty) {
- throw new AnalysisException("The rate source does not support a user-specified schema.")
- }
-
- (shortName(), RateSourceProvider.SCHEMA)
- }
-
- override def createSource(
- sqlContext: SQLContext,
- metadataPath: String,
- schema: Option[StructType],
- providerName: String,
- parameters: Map[String, String]): Source = {
- val params = CaseInsensitiveMap(parameters)
-
- val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L)
- if (rowsPerSecond <= 0) {
- throw new IllegalArgumentException(
- s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " +
- "must be positive")
- }
-
- val rampUpTimeSeconds =
- params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L)
- if (rampUpTimeSeconds < 0) {
- throw new IllegalArgumentException(
- s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " +
- "must not be negative")
- }
-
- val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse(
- sqlContext.sparkContext.defaultParallelism)
- if (numPartitions <= 0) {
- throw new IllegalArgumentException(
- s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " +
- "must be positive")
- }
-
- new RateStreamSource(
- sqlContext,
- metadataPath,
- rowsPerSecond,
- rampUpTimeSeconds,
- numPartitions,
- params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing
- )
- }
-
- override def createContinuousReader(
- schema: Optional[StructType],
- checkpointLocation: String,
- options: DataSourceOptions): ContinuousReader = {
- new RateStreamContinuousReader(options)
- }
-
- override def shortName(): String = "rate"
-}
-
-object RateSourceProvider {
- val SCHEMA =
- StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)
-
- val VERSION = 1
-}
-
-class RateStreamSource(
- sqlContext: SQLContext,
- metadataPath: String,
- rowsPerSecond: Long,
- rampUpTimeSeconds: Long,
- numPartitions: Int,
- useManualClock: Boolean) extends Source with Logging {
-
- import RateSourceProvider._
- import RateStreamSource._
-
- val clock = if (useManualClock) new ManualClock else new SystemClock
-
- private val maxSeconds = Long.MaxValue / rowsPerSecond
-
- if (rampUpTimeSeconds > maxSeconds) {
- throw new ArithmeticException(
- s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
- s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
- }
-
- private val startTimeMs = {
- val metadataLog =
- new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) {
- override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
- val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
- writer.write("v" + VERSION + "\n")
- writer.write(metadata.json)
- writer.flush
- }
-
- override def deserialize(in: InputStream): LongOffset = {
- val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
- // HDFSMetadataLog guarantees that it never creates a partial file.
- assert(content.length != 0)
- if (content(0) == 'v') {
- val indexOfNewLine = content.indexOf("\n")
- if (indexOfNewLine > 0) {
- val version = parseVersion(content.substring(0, indexOfNewLine), VERSION)
- LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
- } else {
- throw new IllegalStateException(
- s"Log file was malformed: failed to detect the log file version line.")
- }
- } else {
- throw new IllegalStateException(
- s"Log file was malformed: failed to detect the log file version line.")
- }
- }
- }
-
- metadataLog.get(0).getOrElse {
- val offset = LongOffset(clock.getTimeMillis())
- metadataLog.add(0, offset)
- logInfo(s"Start time: $offset")
- offset
- }.offset
- }
-
- /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */
- @volatile private var lastTimeMs = startTimeMs
-
- override def schema: StructType = RateSourceProvider.SCHEMA
-
- override def getOffset: Option[Offset] = {
- val now = clock.getTimeMillis()
- if (lastTimeMs < now) {
- lastTimeMs = now
- }
- Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs)))
- }
-
- override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
- val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L)
- val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
- assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
- if (endSeconds > maxSeconds) {
- throw new ArithmeticException("Integer overflow. Max offset with " +
- s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
- }
- // Fix "lastTimeMs" for recovery
- if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) {
- lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs
- }
- val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
- val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
- logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
- s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
-
- if (rangeStart == rangeEnd) {
- return sqlContext.internalCreateDataFrame(
- sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
- }
-
- val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
- val relativeMsPerValue =
- TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
-
- val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
- val relative = math.round((v - rangeStart) * relativeMsPerValue)
- InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
- }
- sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
- }
-
- override def stop(): Unit = {}
-
- override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " +
- s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]"
-}
-
-object RateStreamSource {
-
- /** Calculate the end value we will emit at the time `seconds`. */
- def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
- // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
- // Then speedDeltaPerSecond = 2
- //
- // seconds = 0 1 2 3 4 5 6
- // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds)
- // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2
- val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
- if (seconds <= rampUpTimeSeconds) {
- // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to
- // avoid overflow
- if (seconds % 2 == 1) {
- (seconds + 1) / 2 * speedDeltaPerSecond * seconds
- } else {
- seconds / 2 * speedDeltaPerSecond * (seconds + 1)
- }
- } else {
- // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds
- val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
- rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
- }
- }
-}
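
For reference, the ramp-up arithmetic documented in the deleted comment above can be checked with a small standalone sketch; the helper below simply mirrors the removed valueAtSecond and reproduces the worked example for rowsPerSecond = 10 and rampUpTimeSeconds = 4 (the object name is illustrative only).

    object RampUpCheck {
      // Mirrors the removed RateStreamSource.valueAtSecond, for illustration only.
      def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
        val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
        if (seconds <= rampUpTimeSeconds) {
          // (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2, ordered to avoid overflow
          if (seconds % 2 == 1) (seconds + 1) / 2 * speedDeltaPerSecond * seconds
          else seconds / 2 * speedDeltaPerSecond * (seconds + 1)
        } else {
          valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) +
            (seconds - rampUpTimeSeconds) * rowsPerSecond
        }
      }

      def main(args: Array[String]): Unit = {
        // Expected end values per the comment table: 0, 2, 6, 12, 20, 30, 40
        (0L to 6L).foreach(s => println(s"second $s -> ${valueAtSecond(s, 10, 4)}"))
      }
    }
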
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index e7982d7880ceb..290de873c5cfb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -356,25 +356,7 @@ abstract class StreamExecution(
private def isInterruptedByStop(e: Throwable): Boolean = {
if (state.get == TERMINATED) {
- e match {
- // InterruptedIOException - thrown when an I/O operation is interrupted
- // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted
- case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
- true
- // The cause of the following exceptions may be one of the above exceptions:
- //
- // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as
- // BiFunction.apply
- // ExecutionException - thrown by codes running in a thread pool and these codes throw an
- // exception
- // UncheckedExecutionException - thrown by codes that cannot throw a checked
- // ExecutionException, such as BiFunction.apply
- case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException)
- if e2.getCause != null =>
- isInterruptedByStop(e2.getCause)
- case _ =>
- false
- }
+ StreamExecution.isInterruptionException(e)
} else {
false
}
@@ -396,24 +378,6 @@ abstract class StreamExecution(
}
}
- /**
- * Signals to the thread executing micro-batches that it should stop running after the next
- * batch. This method blocks until the thread stops running.
- */
- override def stop(): Unit = {
- // Set the state to TERMINATED so that the batching thread knows that it was interrupted
- // intentionally
- state.set(TERMINATED)
- if (queryExecutionThread.isAlive) {
- sparkSession.sparkContext.cancelJobGroup(runId.toString)
- queryExecutionThread.interrupt()
- queryExecutionThread.join()
- // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak
- sparkSession.sparkContext.cancelJobGroup(runId.toString)
- }
- logInfo(s"Query $prettyIdString was stopped")
- }
-
/**
* Blocks the current thread until processing for data from the given `source` has reached at
* least the given `Offset`. This method is intended for use primarily when writing tests.
@@ -565,6 +529,26 @@ abstract class StreamExecution(
object StreamExecution {
val QUERY_ID_KEY = "sql.streaming.queryId"
+
+ def isInterruptionException(e: Throwable): Boolean = e match {
+ // InterruptedIOException - thrown when an I/O operation is interrupted
+ // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted
+ case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
+ true
+ // The cause of the following exceptions may be one of the above exceptions:
+ //
+    // UncheckedIOException - thrown by code that cannot throw a checked IOException, such as
+    //   BiFunction.apply
+    // ExecutionException - thrown by code running in a thread pool when that code throws an
+    //   exception
+    // UncheckedExecutionException - thrown by code that cannot throw a checked
+    //   ExecutionException, such as BiFunction.apply
+ case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException)
+ if e2.getCause != null =>
+ isInterruptionException(e2.getCause)
+ case _ =>
+ false
+ }
}
/**
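
The helper extracted into the StreamExecution companion object just unwraps interrupt-style exceptions that may be nested inside wrapper exceptions. A self-contained approximation of that behavior (it leaves out Guava's UncheckedExecutionException, and the names here are illustrative) is:

    import java.io.{IOException, InterruptedIOException, UncheckedIOException}
    import java.nio.channels.ClosedByInterruptException
    import java.util.concurrent.ExecutionException

    object InterruptionCheck {
      // Simplified model of StreamExecution.isInterruptionException, for illustration;
      // Guava's UncheckedExecutionException is omitted to keep this dependency-free.
      def isInterruption(e: Throwable): Boolean = e match {
        case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
          true
        case e2 @ (_: UncheckedIOException | _: ExecutionException) if e2.getCause != null =>
          isInterruption(e2.getCause)
        case _ => false
      }

      def main(args: Array[String]): Unit = {
        // An interrupt wrapped by a thread pool is still recognized as an intentional stop.
        println(isInterruption(new ExecutionException(new InterruptedException())))  // true
        // An ordinary I/O failure is not.
        println(isInterruption(new UncheckedIOException(new IOException("disk"))))    // false
      }
    }
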
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index 845c8d2c14e43..24195b5657e8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -20,13 +20,12 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LeafNode
-import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.sources.v2.DataSourceV2
-import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2}
object StreamingRelation {
def apply(dataSource: DataSource): StreamingRelation = {
@@ -43,7 +42,7 @@ object StreamingRelation {
* passing to [[StreamExecution]] to run a query.
*/
case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute])
- extends LeafNode {
+ extends LeafNode with MultiInstanceRelation {
override def isStreaming: Boolean = true
override def toString: String = sourceName
@@ -54,6 +53,8 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output:
override def computeStats(): Statistics = Statistics(
sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes)
)
+
+ override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))
}
/**
@@ -63,8 +64,9 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output:
case class StreamingExecutionRelation(
source: BaseStreamingSource,
output: Seq[Attribute])(session: SparkSession)
- extends LeafNode {
+ extends LeafNode with MultiInstanceRelation {
+ override def otherCopyArgs: Seq[AnyRef] = session :: Nil
override def isStreaming: Boolean = true
override def toString: String = source.toString
@@ -75,6 +77,8 @@ case class StreamingExecutionRelation(
override def computeStats(): Statistics = Statistics(
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
)
+
+ override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
}
// We have to pack in the V1 data source as a shim, for the case when a source implements
@@ -93,13 +97,16 @@ case class StreamingRelationV2(
extraOptions: Map[String, String],
output: Seq[Attribute],
v1Relation: Option[StreamingRelation])(session: SparkSession)
- extends LeafNode {
+ extends LeafNode with MultiInstanceRelation {
+ override def otherCopyArgs: Seq[AnyRef] = session :: Nil
override def isStreaming: Boolean = true
override def toString: String = sourceName
override def computeStats(): Statistics = Statistics(
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
)
+
+ override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
}
/**
@@ -109,8 +116,9 @@ case class ContinuousExecutionRelation(
source: ContinuousReadSupport,
extraOptions: Map[String, String],
output: Seq[Attribute])(session: SparkSession)
- extends LeafNode {
+ extends LeafNode with MultiInstanceRelation {
+ override def otherCopyArgs: Seq[AnyRef] = session :: Nil
override def isStreaming: Boolean = true
override def toString: String = source.toString
@@ -121,6 +129,8 @@ case class ContinuousExecutionRelation(
override def computeStats(): Statistics = Statistics(
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
)
+
+ override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index c351f658cb955..afa664eb76525 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -167,7 +167,8 @@ case class StreamingSymmetricHashJoinExec(
val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
override def requiredChildDistribution: Seq[Distribution] =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
+ ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
override def output: Seq[Attribute] = joinType match {
case _: InnerLike => left.output ++ right.output
@@ -186,6 +187,17 @@ case class StreamingSymmetricHashJoinExec(
s"${getClass.getSimpleName} should not take $x as the JoinType")
}
+ override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
+ val watermarkUsedForStateCleanup =
+ stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty
+
+    // Latest watermark value is more than that used in the previously executed plan
+ val watermarkHasChanged =
+ eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
+
+ watermarkUsedForStateCleanup && watermarkHasChanged
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator
val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
@@ -318,8 +330,7 @@ case class StreamingSymmetricHashJoinExec(
// outer join) if possible. In all cases, nothing needs to be outputted, hence the removal
// needs to be done greedily by immediately consuming the returned iterator.
val cleanupIter = joinType match {
- case Inner =>
- leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
+ case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
case LeftOuter => rightSideJoiner.removeOldState()
case RightOuter => leftSideJoiner.removeOldState()
case _ => throwBadJoinTypeException()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala
new file mode 100644
index 0000000000000..80865669558dd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.SparkPlan
+
+class WatermarkTracker extends Logging {
+ private val operatorToWatermarkMap = mutable.HashMap[Int, Long]()
+ private var watermarkMs: Long = 0
+ private var updated = false
+
+ def setWatermark(newWatermarkMs: Long): Unit = synchronized {
+ watermarkMs = newWatermarkMs
+ }
+
+ def updateWatermark(executedPlan: SparkPlan): Unit = synchronized {
+ val watermarkOperators = executedPlan.collect {
+ case e: EventTimeWatermarkExec => e
+ }
+ if (watermarkOperators.isEmpty) return
+
+
+ watermarkOperators.zipWithIndex.foreach {
+ case (e, index) if e.eventTimeStats.value.count > 0 =>
+ logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}")
+ val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs
+ val prevWatermarkMs = operatorToWatermarkMap.get(index)
+ if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) {
+ operatorToWatermarkMap.put(index, newWatermarkMs)
+ }
+
+ // Populate 0 if we haven't seen any data yet for this watermark node.
+ case (_, index) =>
+ if (!operatorToWatermarkMap.isDefinedAt(index)) {
+ operatorToWatermarkMap.put(index, 0)
+ }
+ }
+
+ // Update the global watermark to the minimum of all watermark nodes.
+ // This is the safest option, because only the global watermark is fault-tolerant. Making
+ // it the minimum of all individual watermarks guarantees it will never advance past where
+ // any individual watermark operator would be if it were in a plan by itself.
+ val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2
+ if (newWatermarkMs > watermarkMs) {
+ logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms")
+ watermarkMs = newWatermarkMs
+ updated = true
+ } else {
+ logDebug(s"Event time didn't move: $newWatermarkMs < $watermarkMs")
+ updated = false
+ }
+ }
+
+ def currentWatermark: Long = synchronized { watermarkMs }
+}
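
The new tracker advances the global watermark to the minimum across all watermark operators and never moves it backwards. A simplified, self-contained model of that policy (the ToyWatermarkTracker name and the millisecond values are made up for illustration) might look like:

    import scala.collection.mutable

    // Toy model of the min-of-operators, monotonically increasing watermark policy.
    class ToyWatermarkTracker {
      private val perOperator = mutable.HashMap[Int, Long]()
      private var watermarkMs = 0L

      // newWatermarks: operator index -> candidate watermark seen this batch (max event time - delay).
      def update(newWatermarks: Map[Int, Long]): Long = {
        newWatermarks.foreach { case (idx, wm) =>
          if (wm > perOperator.getOrElse(idx, Long.MinValue)) perOperator.put(idx, wm)
        }
        val candidate = perOperator.values.min
        if (candidate > watermarkMs) watermarkMs = candidate  // the global watermark never regresses
        watermarkMs
      }
    }

    object ToyWatermarkTrackerExample extends App {
      val tracker = new ToyWatermarkTracker
      println(tracker.update(Map(0 -> 5000L, 1 -> 2000L)))  // 2000: bounded by the slowest operator
      println(tracker.update(Map(0 -> 6000L, 1 -> 1000L)))  // 2000: operator 1 regressed, global holds
      println(tracker.update(Map(1 -> 7000L)))               // 6000: min of (6000, 7000)
    }
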
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
index db600866067bc..cfba1001c6de0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
@@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2}
-import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
new file mode 100644
index 0000000000000..a7ccce10b0cee
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader}
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset}
+import org.apache.spark.util.{NextIterator, ThreadUtils}
+
+class ContinuousDataSourceRDDPartition(
+ val index: Int,
+ val inputPartition: InputPartition[UnsafeRow])
+ extends Partition with Serializable {
+
+ // This is semantically a lazy val - it's initialized once the first time a call to
+ // ContinuousDataSourceRDD.compute() needs to access it, so it can be shared across
+ // all compute() calls for a partition. This ensures that one compute() picks up where the
+ // previous one ended.
+  // We don't actually make it a lazy val because it needs input which isn't available here.
+ // This will only be initialized on the executors.
+ private[continuous] var queueReader: ContinuousQueuedDataReader = _
+}
+
+/**
+ * The bottom-most RDD of a continuous processing read task. Wraps a [[ContinuousQueuedDataReader]]
+ * to read from the remote source, and polls that queue for incoming rows.
+ *
+ * Note that continuous processing calls compute() multiple times, and the same
+ * [[ContinuousQueuedDataReader]] instance will/must be shared between each call for the same split.
+ */
+class ContinuousDataSourceRDD(
+ sc: SparkContext,
+ dataQueueSize: Int,
+ epochPollIntervalMs: Long,
+ @transient private val readerFactories: Seq[InputPartition[UnsafeRow]])
+ extends RDD[UnsafeRow](sc, Nil) {
+
+ override protected def getPartitions: Array[Partition] = {
+ readerFactories.zipWithIndex.map {
+ case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition)
+ }.toArray
+ }
+
+ /**
+ * Initialize the shared reader for this partition if needed, then read rows from it until
+ * it returns null to signal the end of the epoch.
+ */
+ override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+ // If attempt number isn't 0, this is a task retry, which we don't support.
+ if (context.attemptNumber() != 0) {
+ throw new ContinuousTaskRetryException()
+ }
+
+ val readerForPartition = {
+ val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition]
+ if (partition.queueReader == null) {
+ partition.queueReader =
+ new ContinuousQueuedDataReader(
+ partition.inputPartition, context, dataQueueSize, epochPollIntervalMs)
+ }
+
+ partition.queueReader
+ }
+
+ new NextIterator[UnsafeRow] {
+ override def getNext(): UnsafeRow = {
+ readerForPartition.next() match {
+ case null =>
+ finished = true
+ null
+ case row => row
+ }
+ }
+
+ override def close(): Unit = {}
+ }
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations()
+ }
+}
+
+object ContinuousDataSourceRDD {
+ private[continuous] def getContinuousReader(
+ reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = {
+ reader match {
+ case r: ContinuousInputPartitionReader[UnsafeRow] => r
+ case wrapped: RowToUnsafeInputPartitionReader =>
+ wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]]
+ case _ =>
+ throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
+ }
+ }
+}
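
The partition object caches its queue reader so that each compute() call for the same split resumes from where the previous epoch ended. A stripped-down sketch of that init-once-per-partition pattern, with the Spark types replaced by toy stand-ins, is:

    // Toy stand-ins for the Spark types, to show why the reader lives on the partition object.
    class ToyQueueReader {
      private var pos = 0L
      def next(): Long = { pos += 1; pos }  // continues from wherever the previous call stopped
    }

    class ToyPartition(val index: Int) {
      // Semantically a lazy val: created on the first compute() for this partition, then reused.
      var queueReader: ToyQueueReader = _
    }

    object ToyComputeExample extends App {
      def compute(p: ToyPartition): Long = {
        if (p.queueReader == null) p.queueReader = new ToyQueueReader
        p.queueReader.next()
      }

      val partition = new ToyPartition(0)
      // Successive "epochs" on the same partition pick up where the previous one ended.
      println(compute(partition))  // 1
      println(compute(partition))  // 2
    }
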
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
deleted file mode 100644
index cf02c0dda25d7..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ /dev/null
@@ -1,222 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.continuous
-
-import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit}
-import java.util.concurrent.atomic.AtomicBoolean
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark._
-import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader}
-import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset}
-import org.apache.spark.util.ThreadUtils
-
-class ContinuousDataSourceRDD(
- sc: SparkContext,
- sqlContext: SQLContext,
- @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]])
- extends RDD[UnsafeRow](sc, Nil) {
-
- private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
- private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs
-
- override protected def getPartitions: Array[Partition] = {
- readerFactories.asScala.zipWithIndex.map {
- case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
- }.toArray
- }
-
- override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
- // If attempt number isn't 0, this is a task retry, which we don't support.
- if (context.attemptNumber() != 0) {
- throw new ContinuousTaskRetryException()
- }
-
- val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]]
- .readerFactory.createDataReader()
-
- val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
-
- // This queue contains two types of messages:
- // * (null, null) representing an epoch boundary.
- // * (row, off) containing a data row and its corresponding PartitionOffset.
- val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize)
-
- val epochPollFailed = new AtomicBoolean(false)
- val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
- s"epoch-poll--$coordinatorId--${context.partitionId()}")
- val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed)
- epochPollExecutor.scheduleWithFixedDelay(
- epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
-
- // Important sequencing - we must get start offset before the data reader thread begins
- val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset
-
- val dataReaderFailed = new AtomicBoolean(false)
- val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed)
- dataReaderThread.setDaemon(true)
- dataReaderThread.start()
-
- context.addTaskCompletionListener(_ => {
- dataReaderThread.interrupt()
- epochPollExecutor.shutdown()
- })
-
- val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)
- new Iterator[UnsafeRow] {
- private val POLL_TIMEOUT_MS = 1000
-
- private var currentEntry: (UnsafeRow, PartitionOffset) = _
- private var currentOffset: PartitionOffset = startOffset
- private var currentEpoch =
- context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
-
- override def hasNext(): Boolean = {
- while (currentEntry == null) {
- if (context.isInterrupted() || context.isCompleted()) {
- currentEntry = (null, null)
- }
- if (dataReaderFailed.get()) {
- throw new SparkException("data read failed", dataReaderThread.failureReason)
- }
- if (epochPollFailed.get()) {
- throw new SparkException("epoch poll failed", epochPollRunnable.failureReason)
- }
- currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS)
- }
-
- currentEntry match {
- // epoch boundary marker
- case (null, null) =>
- epochEndpoint.send(ReportPartitionOffset(
- context.partitionId(),
- currentEpoch,
- currentOffset))
- currentEpoch += 1
- currentEntry = null
- false
- // real row
- case (_, offset) =>
- currentOffset = offset
- true
- }
- }
-
- override def next(): UnsafeRow = {
- if (currentEntry == null) throw new NoSuchElementException("No current row was set")
- val r = currentEntry._1
- currentEntry = null
- r
- }
- }
- }
-
- override def getPreferredLocations(split: Partition): Seq[String] = {
- split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations()
- }
-}
-
-case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset
-
-class EpochPollRunnable(
- queue: BlockingQueue[(UnsafeRow, PartitionOffset)],
- context: TaskContext,
- failedFlag: AtomicBoolean)
- extends Thread with Logging {
- private[continuous] var failureReason: Throwable = _
-
- private val epochEndpoint = EpochCoordinatorRef.get(
- context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
- private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
-
- override def run(): Unit = {
- try {
- val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch)
- for (i <- currentEpoch to newEpoch - 1) {
- queue.put((null, null))
- logDebug(s"Sent marker to start epoch ${i + 1}")
- }
- currentEpoch = newEpoch
- } catch {
- case t: Throwable =>
- failureReason = t
- failedFlag.set(true)
- throw t
- }
- }
-}
-
-class DataReaderThread(
- reader: DataReader[UnsafeRow],
- queue: BlockingQueue[(UnsafeRow, PartitionOffset)],
- context: TaskContext,
- failedFlag: AtomicBoolean)
- extends Thread(
- s"continuous-reader--${context.partitionId()}--" +
- s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") {
- private[continuous] var failureReason: Throwable = _
-
- override def run(): Unit = {
- TaskContext.setTaskContext(context)
- val baseReader = ContinuousDataSourceRDD.getBaseReader(reader)
- try {
- while (!context.isInterrupted && !context.isCompleted()) {
- if (!reader.next()) {
- // Check again, since reader.next() might have blocked through an incoming interrupt.
- if (!context.isInterrupted && !context.isCompleted()) {
- throw new IllegalStateException(
- "Continuous reader reported no elements! Reader should have blocked waiting.")
- } else {
- return
- }
- }
-
- queue.put((reader.get().copy(), baseReader.getOffset))
- }
- } catch {
- case _: InterruptedException if context.isInterrupted() =>
- // Continuous shutdown always involves an interrupt; do nothing and shut down quietly.
-
- case t: Throwable =>
- failureReason = t
- failedFlag.set(true)
- // Don't rethrow the exception in this thread. It's not needed, and the default Spark
- // exception handler will kill the executor.
- } finally {
- reader.close()
- }
- }
-}
-
-object ContinuousDataSourceRDD {
- private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = {
- reader match {
- case r: ContinuousDataReader[UnsafeRow] => r
- case wrapped: RowToUnsafeDataReader =>
- wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]]
- case _ =>
- throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index ed22b9100497a..e3d0cea608b2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -29,12 +29,11 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2}
+import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _}
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport
+import org.apache.spark.sql.sources.v2
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset}
-import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{Clock, Utils}
@@ -123,16 +122,7 @@ class ContinuousExecution(
s"Batch $latestEpochId was committed without end epoch offsets!")
}
committedOffsets = nextOffsets.toStreamProgress(sources)
-
- // Get to an epoch ID that has definitely never been sent to a sink before. Since sink
- // commit happens between offset log write and commit log write, this means an epoch ID
- // which is not in the offset log.
- val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse {
- throw new IllegalStateException(
- s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" +
- s"an element.")
- }
- currentBatchId = latestOffsetEpoch + 1
+ currentBatchId = latestEpochId + 1
logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
nextOffsets
@@ -169,7 +159,7 @@ class ContinuousExecution(
var insertedSourceId = 0
val withNewSources = logicalPlan transform {
- case ContinuousExecutionRelation(_, _, output) =>
+ case ContinuousExecutionRelation(source, options, output) =>
val reader = continuousSources(insertedSourceId)
insertedSourceId += 1
val newOutput = reader.readSchema().toAttributes
@@ -182,7 +172,7 @@ class ContinuousExecution(
val loggedOffset = offsets.offsets(0)
val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json))
reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull))
- new StreamingDataSourceV2Relation(newOutput, reader)
+ StreamingDataSourceV2Relation(newOutput, source, options, reader)
}
// Rewire the plan to use the new attributes that were returned by the source.
@@ -200,10 +190,10 @@ class ContinuousExecution(
triggerLogicalPlan.schema,
outputMode,
new DataSourceOptions(extraOptions.asJava))
- val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan)
+ val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan)
val reader = withSink.collect {
- case DataSourceV2Relation(_, r: ContinuousReader) => r
+ case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r
}.head
reportTimeTaken("queryPlanning") {
@@ -238,9 +228,7 @@ class ContinuousExecution(
startTrigger()
if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) {
- stopSources()
if (queryExecutionThread.isAlive) {
- sparkSession.sparkContext.cancelJobGroup(runId.toString)
queryExecutionThread.interrupt()
}
false
@@ -268,12 +256,20 @@ class ContinuousExecution(
SQLExecution.withNewExecutionId(
sparkSessionForQuery, lastExecution)(lastExecution.toRdd)
}
+ } catch {
+ case t: Throwable
+ if StreamExecution.isInterruptionException(t) && state.get() == RECONFIGURING =>
+ logInfo(s"Query $id ignoring exception from reconfiguring: $t")
+ // interrupted by reconfiguration - swallow exception so we can restart the query
} finally {
epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)
SparkEnv.get.rpcEnv.stop(epochEndpoint)
epochUpdateThread.interrupt()
epochUpdateThread.join()
+
+ stopSources()
+ sparkSession.sparkContext.cancelJobGroup(runId.toString)
}
}
@@ -313,16 +309,23 @@ class ContinuousExecution(
synchronized {
if (queryExecutionThread.isAlive) {
commitLog.add(epoch)
- val offset = offsetLog.get(epoch).get.offsets(0).get
+ val offset =
+ continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json)
committedOffsets ++= Seq(continuousSources(0) -> offset)
+ continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset])
} else {
return
}
}
- if (minLogEntriesToMaintain < currentBatchId) {
- offsetLog.purge(currentBatchId - minLogEntriesToMaintain)
- commitLog.purge(currentBatchId - minLogEntriesToMaintain)
+      // Since currentBatchId increases independently in continuous processing mode, the current
+      // committed epoch may be far behind currentBatchId. It is not safe to discard the metadata
+      // with a thresholdBatchId computed from currentBatchId. Since minLogEntriesToMaintain is
+      // used to keep the minimum number of batches that must be retained and made recoverable, we
+      // should keep the specified number of committed metadata entries.
+ if (minLogEntriesToMaintain <= epoch) {
+ offsetLog.purge(epoch + 1 - minLogEntriesToMaintain)
+ commitLog.purge(epoch + 1 - minLogEntriesToMaintain)
}
awaitProgressLock.lock()
@@ -358,6 +361,22 @@ class ContinuousExecution(
}
}
}
+
+ /**
+ * Stops the query execution thread to terminate the query.
+ */
+ override def stop(): Unit = {
+ // Set the state to TERMINATED so that the batching thread knows that it was interrupted
+ // intentionally
+ state.set(TERMINATED)
+ if (queryExecutionThread.isAlive) {
+ // The query execution thread will clean itself up in the finally clause of runContinuous.
+ // We just need to interrupt the long running job.
+ queryExecutionThread.interrupt()
+ queryExecutionThread.join()
+ }
+ logInfo(s"Query $prettyIdString was stopped")
+ }
}
object ContinuousExecution {
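
With the purge threshold now keyed off the committed epoch instead of currentBatchId, the offset and commit logs retain the last minLogEntriesToMaintain committed epochs. A quick arithmetic check of the new condition (the concrete numbers are illustrative):

    object PurgeThresholdCheck extends App {
      // Returns the id below which offset/commit log entries would be purged, if any.
      def purgeThreshold(committedEpoch: Long, minLogEntriesToMaintain: Int): Option[Long] =
        if (minLogEntriesToMaintain <= committedEpoch) {
          Some(committedEpoch + 1 - minLogEntriesToMaintain)
        } else None

      // With minLogEntriesToMaintain = 100 and committed epoch 250, entries below 151 are purged,
      // leaving epochs 151..250 (exactly 100 entries) recoverable.
      println(purgeThreshold(250L, 100))  // Some(151)
      println(purgeThreshold(50L, 100))   // None: too few committed epochs to purge anything yet
    }
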
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
new file mode 100644
index 0000000000000..f38577b6a9f16
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.io.Closeable
+import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
+import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A wrapper for a continuous processing data reader, including a reading queue and epoch markers.
+ *
+ * This will be instantiated once per partition - successive calls to compute() in the
+ * [[ContinuousDataSourceRDD]] will reuse the same reader. This is required to get continuity of
+ * offsets across epochs. Each compute() should call the next() method here until null is returned.
+ */
+class ContinuousQueuedDataReader(
+ partition: InputPartition[UnsafeRow],
+ context: TaskContext,
+ dataQueueSize: Int,
+ epochPollIntervalMs: Long) extends Closeable {
+ private val reader = partition.createPartitionReader()
+
+ // Important sequencing - we must get our starting point before the provider threads start running
+ private var currentOffset: PartitionOffset =
+ ContinuousDataSourceRDD.getContinuousReader(reader).getOffset
+
+ /**
+ * The record types in the read buffer.
+ */
+ sealed trait ContinuousRecord
+ case object EpochMarker extends ContinuousRecord
+ case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord
+
+ private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize)
+
+ private val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
+ private val epochCoordEndpoint = EpochCoordinatorRef.get(
+ context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
+
+ private val epochMarkerExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+ s"epoch-poll--$coordinatorId--${context.partitionId()}")
+ private val epochMarkerGenerator = new EpochMarkerGenerator
+ epochMarkerExecutor.scheduleWithFixedDelay(
+ epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
+
+ private val dataReaderThread = new DataReaderThread
+ dataReaderThread.setDaemon(true)
+ dataReaderThread.start()
+
+ context.addTaskCompletionListener(_ => {
+ this.close()
+ })
+
+ private def shouldStop() = {
+ context.isInterrupted() || context.isCompleted()
+ }
+
+ /**
+ * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done.
+ *
+ * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch
+ * will call next() again to start getting rows.
+ */
+ def next(): UnsafeRow = {
+ val POLL_TIMEOUT_MS = 1000
+ var currentEntry: ContinuousRecord = null
+
+ while (currentEntry == null) {
+ if (shouldStop()) {
+ // Force the epoch to end here. The writer will notice the context is interrupted
+ // or completed and not start a new one. This makes it possible to achieve clean
+ // shutdown of the streaming query.
+ // TODO: The obvious generalization of this logic to multiple stages won't work. It's
+ // invalid to send an epoch marker from the bottom of a task if all its child tasks
+ // haven't sent one.
+ currentEntry = EpochMarker
+ } else {
+ if (dataReaderThread.failureReason != null) {
+ throw new SparkException("Data read failed", dataReaderThread.failureReason)
+ }
+ if (epochMarkerGenerator.failureReason != null) {
+ throw new SparkException(
+ "Epoch marker generation failed",
+ epochMarkerGenerator.failureReason)
+ }
+ currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS)
+ }
+ }
+
+ currentEntry match {
+ case EpochMarker =>
+ epochCoordEndpoint.send(ReportPartitionOffset(
+ context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset))
+ null
+ case ContinuousRow(row, offset) =>
+ currentOffset = offset
+ row
+ }
+ }
+
+ override def close(): Unit = {
+ dataReaderThread.interrupt()
+ epochMarkerExecutor.shutdown()
+ }
+
+ /**
+ * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when
+ * a new row arrives to the [[InputPartitionReader]].
+ */
+ class DataReaderThread extends Thread(
+ s"continuous-reader--${context.partitionId()}--" +
+ s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging {
+ @volatile private[continuous] var failureReason: Throwable = _
+
+ override def run(): Unit = {
+ TaskContext.setTaskContext(context)
+ val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader)
+ try {
+ while (!shouldStop()) {
+ if (!reader.next()) {
+ // Check again, since reader.next() might have blocked through an incoming interrupt.
+ if (!shouldStop()) {
+ throw new IllegalStateException(
+ "Continuous reader reported no elements! Reader should have blocked waiting.")
+ } else {
+ return
+ }
+ }
+
+ queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset))
+ }
+ } catch {
+ case _: InterruptedException =>
+ // Continuous shutdown always involves an interrupt; do nothing and shut down quietly.
+ logInfo(s"shutting down interrupted data reader thread $getName")
+
+ case NonFatal(t) =>
+ failureReason = t
+ logWarning("data reader thread failed", t)
+ // If we throw from this thread, we may kill the executor. Let the parent thread handle
+ // it.
+
+ case t: Throwable =>
+ failureReason = t
+ throw t
+ } finally {
+ reader.close()
+ }
+ }
+ }
+
+ /**
+ * The epoch marker component of [[ContinuousQueuedDataReader]]. Populates the queue with
+ * EpochMarker when a new epoch marker arrives.
+ */
+ class EpochMarkerGenerator extends Runnable with Logging {
+ @volatile private[continuous] var failureReason: Throwable = _
+
+ private val epochCoordEndpoint = EpochCoordinatorRef.get(
+ context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
+ // Note that this is *not* the same as the currentEpoch in [[ContinuousWriteRDD]]! That
+ // field represents the epoch wrt the data being processed. The currentEpoch here is just a
+ // counter to ensure we send the appropriate number of markers if we fall behind the driver.
+ private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
+ override def run(): Unit = {
+ try {
+ val newEpoch = epochCoordEndpoint.askSync[Long](GetCurrentEpoch)
+ // It's possible to fall more than 1 epoch behind if a GetCurrentEpoch RPC ends up taking
+        // a while. We catch up by immediately injecting enough epoch markers. This will
+ // result in some epochs being empty for this partition, but that's fine.
+ for (i <- currentEpoch to newEpoch - 1) {
+ queue.put(EpochMarker)
+ logDebug(s"Sent marker to start epoch ${i + 1}")
+ }
+ currentEpoch = newEpoch
+ } catch {
+ case t: Throwable =>
+ failureReason = t
+ throw t
+ }
+ }
+ }
+}
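
Rows and epoch markers share one blocking queue, and next() returns null when it dequeues a marker so the caller can end the epoch. A minimal, self-contained illustration of that queue protocol (the Record/Row names are stand-ins, not the real types):

    import java.util.concurrent.ArrayBlockingQueue

    object ToyQueueProtocol extends App {
      sealed trait Record
      case object EpochMarker extends Record
      case class Row(value: Long) extends Record

      val queue = new ArrayBlockingQueue[Record](16)
      Seq[Record](Row(1), Row(2), EpochMarker, Row(3)).foreach(queue.put)

      // Mirrors ContinuousQueuedDataReader.next(): emit rows until a marker; null ends the epoch.
      def next(): java.lang.Long = queue.take() match {
        case EpochMarker => null
        case Row(v) => v
      }

      println(Iterator.continually(next()).takeWhile(_ != null).toList)  // List(1, 2)
      println(next())  // 3, the first row of the following epoch
    }
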
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index b63d8d3e20650..516a563bdcc7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -24,11 +24,11 @@ import org.json4s.jackson.Serialization
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair}
-import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2
+import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
+import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
case class RateStreamPartitionOffset(
@@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions)
val creationTime = System.currentTimeMillis()
- val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
- val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
+ val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt
+ val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong
val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble
override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
@@ -57,17 +57,17 @@ class RateStreamContinuousReader(options: DataSourceOptions)
RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
}
- override def readSchema(): StructType = RateSourceProvider.SCHEMA
+ override def readSchema(): StructType = RateStreamProvider.SCHEMA
private var offset: Offset = _
override def setStartOffset(offset: java.util.Optional[Offset]): Unit = {
- this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime))
+ this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime))
}
override def getStartOffset(): Offset = offset
- override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
+ override def planInputPartitions(): java.util.List[InputPartition[Row]] = {
val partitionStartMap = offset match {
case off: RateStreamOffset => off.partitionToValueAndRunTimeMs
case off =>
@@ -85,40 +85,66 @@ class RateStreamContinuousReader(options: DataSourceOptions)
val start = partitionStartMap(i)
// Have each partition advance by numPartitions each row, with starting points staggered
// by their partition index.
- RateStreamContinuousDataReaderFactory(
+ RateStreamContinuousInputPartition(
start.value,
start.runTimeMs,
i,
numPartitions,
perPartitionRate)
- .asInstanceOf[DataReaderFactory[Row]]
+ .asInstanceOf[InputPartition[Row]]
}.asJava
}
override def commit(end: Offset): Unit = {}
override def stop(): Unit = {}
+ private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
+ RateStreamOffset(
+ Range(0, numPartitions).map { i =>
+ // Note that the starting offset is exclusive, so we have to decrement the starting value
+ // by the increment that will later be applied. The first row output in each
+ // partition will have a value equal to the partition index.
+ (i,
+ ValueRunTimeMsPair(
+ (i - numPartitions).toLong,
+ creationTimeMs))
+ }.toMap)
+ }
+
}
-case class RateStreamContinuousDataReaderFactory(
+case class RateStreamContinuousInputPartition(
startValue: Long,
startTimeMs: Long,
partitionIndex: Int,
increment: Long,
rowsPerSecond: Double)
- extends DataReaderFactory[Row] {
- override def createDataReader(): DataReader[Row] =
- new RateStreamContinuousDataReader(
+ extends ContinuousInputPartition[Row] {
+
+ override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[Row] = {
+ val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset]
+ require(rateStreamOffset.partition == partitionIndex,
+ s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}")
+ new RateStreamContinuousInputPartitionReader(
+ rateStreamOffset.currentValue,
+ rateStreamOffset.currentTimeMs,
+ partitionIndex,
+ increment,
+ rowsPerSecond)
+ }
+
+ override def createPartitionReader(): InputPartitionReader[Row] =
+ new RateStreamContinuousInputPartitionReader(
startValue, startTimeMs, partitionIndex, increment, rowsPerSecond)
}
-class RateStreamContinuousDataReader(
+class RateStreamContinuousInputPartitionReader(
startValue: Long,
startTimeMs: Long,
partitionIndex: Int,
increment: Long,
rowsPerSecond: Double)
- extends ContinuousDataReader[Row] {
+ extends ContinuousInputPartitionReader[Row] {
private var nextReadTime: Long = startTimeMs
private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong
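
Because the starting offset is exclusive and each partition advances by numPartitions per row, decrementing partition i's start by numPartitions makes its first emitted value equal to i. A small check of that arithmetic, assuming five partitions:

    object RateInitialOffsetCheck extends App {
      val numPartitions = 5

      // Mirrors createInitialOffset: partition i starts (exclusively) at i - numPartitions.
      val startValues = (0 until numPartitions).map(i => i -> (i - numPartitions).toLong).toMap

      // Each partition then advances by numPartitions per row, so the first emitted value in
      // partition i equals i.
      val firstEmitted = startValues.map { case (i, start) => i -> (start + numPartitions) }
      firstEmitted.toSeq.sorted.foreach { case (i, v) => println(s"partition $i emits $v first") }
    }
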
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
new file mode 100644
index 0000000000000..ef5f0da1e7cc2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.{Partition, SparkEnv, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.util.Utils
+
+/**
+ * The RDD writing to a sink in continuous processing.
+ *
+ * Within each task, we repeatedly call prev.compute(). Each resulting iterator contains the data
+ * to be written for one epoch, which we commit and forward to the driver.
+ *
+ * We keep repeating prev.compute() and writing new epochs until the query is shut down.
+ */
+class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow])
+ extends RDD[Unit](prev) {
+
+ override val partitioner = prev.partitioner
+
+ override def getPartitions: Array[Partition] = prev.partitions
+
+ override def compute(split: Partition, context: TaskContext): Iterator[Unit] = {
+ val epochCoordinator = EpochCoordinatorRef.get(
+ context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+ SparkEnv.get)
+ EpochTracker.initializeCurrentEpoch(
+ context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
+
+ while (!context.isInterrupted() && !context.isCompleted()) {
+ var dataWriter: DataWriter[InternalRow] = null
+ // write the data and commit this writer.
+ Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
+ try {
+ val dataIterator = prev.compute(split, context)
+ dataWriter = writeTask.createDataWriter(
+ context.partitionId(),
+ context.attemptNumber(),
+ EpochTracker.getCurrentEpoch.get)
+ while (dataIterator.hasNext) {
+ dataWriter.write(dataIterator.next())
+ }
+ logInfo(s"Writer for partition ${context.partitionId()} " +
+ s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.")
+ val msg = dataWriter.commit()
+ epochCoordinator.send(
+ CommitPartitionEpoch(
+ context.partitionId(),
+ EpochTracker.getCurrentEpoch.get,
+ msg)
+ )
+ logInfo(s"Writer for partition ${context.partitionId()} " +
+ s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.")
+ EpochTracker.incrementCurrentEpoch()
+ } catch {
+ case _: InterruptedException =>
+ // Continuous shutdown always involves an interrupt. Just finish the task.
+ }
+ })(catchBlock = {
+ // If there is an error, abort this writer. We enter this callback in the middle of
+ // rethrowing an exception, so compute() will stop executing at this point.
+ logError(s"Writer for partition ${context.partitionId()} is aborting.")
+ if (dataWriter != null) dataWriter.abort()
+ logError(s"Writer for partition ${context.partitionId()} aborted.")
+ })
+ }
+
+ Iterator()
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ prev = null
+ }
+}
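
Each pass through the write loop consumes one epoch's iterator from the parent RDD, commits it, and bumps the epoch counter. A toy model of that loop, with the writer and epoch-coordinator RPCs replaced by println stand-ins:

    object ToyWriteLoop extends App {
      // Stand-in for prev.compute(): one epoch's worth of rows for this partition.
      def computeEpoch(epoch: Long): Iterator[String] = Iterator(s"row-$epoch-a", s"row-$epoch-b")

      var currentEpoch = 0L  // EpochTracker.initializeCurrentEpoch(startEpoch)
      val lastEpoch = 2L     // the real loop runs until the task is interrupted or completed

      while (currentEpoch <= lastEpoch) {
        val written = computeEpoch(currentEpoch).toList      // write each row of this epoch
        println(s"epoch $currentEpoch committed: $written")  // stand-in for CommitPartitionEpoch
        currentEpoch += 1                                    // EpochTracker.incrementCurrentEpoch()
      }
    }
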
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index cc6808065c0cd..8877ebeb26735 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -137,30 +137,71 @@ private[continuous] class EpochCoordinator(
private val partitionOffsets =
mutable.Map[(Long, Int), PartitionOffset]()
+ private var lastCommittedEpoch = startEpoch - 1
+ // Remembers epochs that have to wait for previous epochs to be committed first.
+ private val epochsWaitingToBeCommitted = mutable.HashSet.empty[Long]
+
private def resolveCommitsAtEpoch(epoch: Long) = {
- val thisEpochCommits =
- partitionCommits.collect { case ((e, _), msg) if e == epoch => msg }
+ val thisEpochCommits = findPartitionCommitsForEpoch(epoch)
val nextEpochOffsets =
partitionOffsets.collect { case ((e, _), o) if e == epoch => o }
if (thisEpochCommits.size == numWriterPartitions &&
nextEpochOffsets.size == numReaderPartitions) {
- logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.")
- // Sequencing is important here. We must commit to the writer before recording the commit
- // in the query, or we will end up dropping the commit if we restart in the middle.
- writer.commit(epoch, thisEpochCommits.toArray)
- query.commit(epoch)
-
- // Cleanup state from before this epoch, now that we know all partitions are forever past it.
- for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) {
- partitionCommits.remove(k)
- }
- for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
- partitionOffsets.remove(k)
+
+ // To keep commits sequential, check that the last committed epoch is the immediately
+ // preceding one. If it is not, park the epoch currently being processed in the set of
+ // epochs waiting to be committed; otherwise commit it now.
+ if (lastCommittedEpoch != epoch - 1) {
+ logDebug(s"Epoch $epoch has received commits from all partitions " +
+ s"and is waiting for epoch ${epoch - 1} to be committed first.")
+ epochsWaitingToBeCommitted.add(epoch)
+ } else {
+ commitEpoch(epoch, thisEpochCommits)
+ lastCommittedEpoch = epoch
+
+ // Commit subsequent epochs that are waiting to be committed.
+ var nextEpoch = lastCommittedEpoch + 1
+ while (epochsWaitingToBeCommitted.contains(nextEpoch)) {
+ val nextEpochCommits = findPartitionCommitsForEpoch(nextEpoch)
+ commitEpoch(nextEpoch, nextEpochCommits)
+
+ epochsWaitingToBeCommitted.remove(nextEpoch)
+ lastCommittedEpoch = nextEpoch
+ nextEpoch += 1
+ }
+
+ // Cleanup state from before last committed epoch,
+ // now that we know all partitions are forever past it.
+ for (k <- partitionCommits.keys.filter { case (e, _) => e < lastCommittedEpoch }) {
+ partitionCommits.remove(k)
+ }
+ for (k <- partitionOffsets.keys.filter { case (e, _) => e < lastCommittedEpoch }) {
+ partitionOffsets.remove(k)
+ }
}
}
}
+ /**
+ * Collect per-partition commits for an epoch.
+ */
+ private def findPartitionCommitsForEpoch(epoch: Long): Iterable[WriterCommitMessage] = {
+ partitionCommits.collect { case ((e, _), msg) if e == epoch => msg }
+ }
+
+ /**
+ * Commit epoch to the offset log.
+ */
+ private def commitEpoch(epoch: Long, messages: Iterable[WriterCommitMessage]): Unit = {
+ logDebug(s"Epoch $epoch has received commits from all partitions " +
+ s"and is ready to be committed. Committing epoch $epoch.")
+ // Sequencing is important here. We must commit to the writer before recording the commit
+ // in the query, or we will end up dropping the commit if we restart in the middle.
+ writer.commit(epoch, messages.toArray)
+ query.commit(epoch)
+ }
+
override def receive: PartialFunction[Any, Unit] = {
// If we just drop these messages, we won't do any writes to the query. The lame duck tasks
// won't shed errors or anything.
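The commit-sequencing change above can be read in isolation: an epoch whose partition commits and offsets are all in hand is still held back until its predecessor has been committed, and committing one epoch drains any successors parked in epochsWaitingToBeCommitted. A minimal standalone sketch of that rule, detached from the RPC plumbing (the object and method names here are hypothetical):

    import scala.collection.mutable

    // Illustrative sketch of the ordering rule only; not coordinator code.
    object EpochOrderingSketch {
      private var lastCommitted = -1L                    // plays the role of lastCommittedEpoch
      private val waiting = mutable.HashSet.empty[Long]  // plays the role of epochsWaitingToBeCommitted

      // Called when an epoch has all partition commits; `doCommit` stands in for
      // writer.commit(...) followed by query.commit(...).
      def onEpochReady(epoch: Long)(doCommit: Long => Unit): Unit = {
        if (lastCommitted != epoch - 1) {
          waiting += epoch                               // arrived out of order: park it
        } else {
          doCommit(epoch)
          lastCommitted = epoch
          var next = lastCommitted + 1
          while (waiting.remove(next)) {                 // drain successors that were waiting
            doCommit(next)
            lastCommitted = next
            next += 1
          }
        }
      }
    }

With onEpochReady(1)(...) arriving before onEpochReady(0)(...), nothing is committed until epoch 0 shows up, at which point epochs 0 and 1 are committed in order.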
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala
new file mode 100644
index 0000000000000..bc0ae428d4521
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.atomic.AtomicLong
+
+/**
+ * Tracks the current continuous processing epoch within a task. Call
+ * EpochTracker.getCurrentEpoch to get the current epoch.
+ */
+object EpochTracker {
+ // The current epoch. Note that this is a shared reference; ContinuousWriteRDD.compute() will
+ // update the underlying AtomicLong as it finishes epochs. Other code should only read the value.
+ private val currentEpoch: ThreadLocal[AtomicLong] = new ThreadLocal[AtomicLong] {
+ override def initialValue() = new AtomicLong(-1)
+ }
+
+ /**
+ * Get the current epoch for the current task, or None if the task has no current epoch.
+ */
+ def getCurrentEpoch: Option[Long] = {
+ currentEpoch.get().get() match {
+ case n if n < 0 => None
+ case e => Some(e)
+ }
+ }
+
+ /**
+ * Increment the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]]
+ * between epochs.
+ */
+ def incrementCurrentEpoch(): Unit = {
+ currentEpoch.get().incrementAndGet()
+ }
+
+ /**
+ * Initialize the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]]
+ * at the beginning of a task.
+ */
+ def initializeCurrentEpoch(startEpoch: Long): Unit = {
+ currentEpoch.get().set(startEpoch)
+ }
+}
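EpochTracker's contract is small enough to state as a sketch: on a given task thread the epoch starts undefined, becomes whatever the task initializes it to, and then advances by one per completed epoch. A minimal illustration, assuming it runs on a single, freshly started task thread:

    // Per-thread semantics of EpochTracker (illustrative; assumes a fresh task thread).
    assert(EpochTracker.getCurrentEpoch.isEmpty)       // no epoch initialized yet on this thread
    EpochTracker.initializeCurrentEpoch(5)
    assert(EpochTracker.getCurrentEpoch.contains(5))
    EpochTracker.incrementCurrentEpoch()               // done by ContinuousWriteRDD between epochs
    assert(EpochTracker.getCurrentEpoch.contains(6))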
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
new file mode 100644
index 0000000000000..943c731a70529
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+
+/**
+ * The logical plan for writing data in a continuous stream.
+ */
+case class WriteToContinuousDataSource(
+ writer: StreamWriter, query: LogicalPlan) extends LogicalPlan {
+ override def children: Seq[LogicalPlan] = Seq(query)
+ override def output: Seq[Attribute] = Nil
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
new file mode 100644
index 0000000000000..e0af3a2f1b85d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory}
+import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
+import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.sql.sources.v2.writer._
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.util.Utils
+
+/**
+ * The physical plan for writing data into a continuous processing [[StreamWriter]].
+ */
+case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan)
+ extends SparkPlan with Logging {
+ override def children: Seq[SparkPlan] = Seq(query)
+ override def output: Seq[Attribute] = Nil
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val writerFactory = writer match {
+ case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
+ case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
+ }
+
+ val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
+
+ logInfo(s"Start processing data source writer: $writer. " +
+ s"The input RDD has ${rdd.partitions.length} partitions.")
+ EpochCoordinatorRef.get(
+ sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+ sparkContext.env)
+ .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))
+
+ try {
+ // Force the RDD to run so continuous processing starts; no data is actually being collected
+ // to the driver, as ContinuousWriteRDD outputs nothing.
+ rdd.collect()
+ } catch {
+ case _: InterruptedException =>
+ // Interruption is how continuous queries are ended, so accept and ignore the exception.
+ case cause: Throwable =>
+ cause match {
+ // Do not wrap interruption exceptions that will be handled by streaming specially.
+ case _ if StreamExecution.isInterruptionException(cause) => throw cause
+ // Only wrap non fatal exceptions.
+ case NonFatal(e) => throw new SparkException("Writing job aborted.", e)
+ case _ => throw cause
+ }
+ }
+
+ sparkContext.emptyRDD
+ }
+}
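The rule that turns the logical WriteToContinuousDataSource node into this physical operator is not part of this hunk; as a sketch of what it looks like under the usual Strategy/planLater pattern (the object name is hypothetical):

    // Hypothetical planner rule; the real mapping lives in the planner strategies, not in this diff.
    import org.apache.spark.sql.Strategy
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.SparkPlan

    object ContinuousWriteStrategy extends Strategy {
      override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case WriteToContinuousDataSource(writer, query) =>
          WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil
        case _ => Nil
      }
    }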
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
new file mode 100644
index 0000000000000..cf6572d3de1f7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import java.util.UUID
+
+import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.NextIterator
+
+case class ContinuousShuffleReadPartition(
+ index: Int,
+ queueSize: Int,
+ numShuffleWriters: Int,
+ epochIntervalMs: Long)
+ extends Partition {
+ // Initialized only on the executor, and only once even as we call compute() multiple times.
+ lazy val (reader: ContinuousShuffleReader, endpoint) = {
+ val env = SparkEnv.get.rpcEnv
+ val receiver = new RPCContinuousShuffleReader(
+ queueSize, numShuffleWriters, epochIntervalMs, env)
+ val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver)
+
+ TaskContext.get().addTaskCompletionListener { ctx =>
+ env.stop(endpoint)
+ }
+ (receiver, endpoint)
+ }
+}
+
+/**
+ * RDD at the reduce side of each continuous processing shuffle task. Upstream tasks send their
+ * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks
+ * polls its receiver until an epoch marker is sent.
+ *
+ * @param sc the RDD context
+ * @param numPartitions the number of read partitions for this RDD
+ * @param queueSize the size of the row buffers to use
+ * @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD
+ * @param epochIntervalMs the checkpoint interval of the streaming query
+ */
+class ContinuousShuffleReadRDD(
+ sc: SparkContext,
+ numPartitions: Int,
+ queueSize: Int = 1024,
+ numShuffleWriters: Int = 1,
+ epochIntervalMs: Long = 1000)
+ extends RDD[UnsafeRow](sc, Nil) {
+
+ override protected def getPartitions: Array[Partition] = {
+ (0 until numPartitions).map { partIndex =>
+ ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs)
+ }.toArray
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+ split.asInstanceOf[ContinuousShuffleReadPartition].reader.read()
+ }
+}
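A test-style way to exercise one reader partition is to push messages straight into its RPC endpoint and then drain the epoch via compute(). A sketch under assumed setup: a local SparkContext `sc`, an active TaskContext (the partition's endpoint is created lazily inside a task), and `row` standing for some UnsafeRow:

    // Sketch only; assumes `sc: SparkContext`, an active TaskContext, and `row: UnsafeRow`.
    val rdd = new ContinuousShuffleReadRDD(sc, numPartitions = 1)
    val part = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition]
    part.endpoint.askSync[Unit](ReceiverRow(0, row))            // one data row from writer 0
    part.endpoint.askSync[Unit](ReceiverEpochMarker(0))         // writer 0 ends the epoch
    val epochRows = rdd.compute(part, TaskContext.get()).toSeq  // iterator ends at the marker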
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala
new file mode 100644
index 0000000000000..42631c90ebc55
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+/**
+ * Trait for reading from a continuous processing shuffle.
+ */
+trait ContinuousShuffleReader {
+ /**
+ * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting
+ * for new rows to arrive, and end the iterator once they've received epoch markers from all
+ * shuffle writers.
+ */
+ def read(): Iterator[UnsafeRow]
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala
similarity index 75%
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala
index 0372ad5270951..47b1f78b24505 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala
@@ -14,12 +14,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.deploy.k8s.submit.steps.initcontainer
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
/**
- * Represents a step in configuring the driver init-container.
+ * Trait for writing to a continuous processing shuffle.
*/
-private[spark] trait InitContainerConfigurationStep {
-
- def configureInitContainer(spec: InitContainerSpec): InitContainerSpec
+trait ContinuousShuffleWriter {
+ def write(epoch: Iterator[UnsafeRow]): Unit
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
new file mode 100644
index 0000000000000..834e84675c7d5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.NextIterator
+
+/**
+ * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming row or an epoch marker.
+ *
+ * Each message comes tagged with writerId, identifying which writer the message is coming
+ * from. The receiver will only begin the next epoch once all writers have sent an epoch
+ * marker ending the current epoch.
+ */
+private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable {
+ def writerId: Int
+}
+private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow)
+ extends RPCContinuousShuffleMessage
+private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContinuousShuffleMessage
+
+/**
+ * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle
+ * writers will send rows here, with continuous shuffle readers polling for new rows as needed.
+ *
+ * TODO: Support multiple source tasks. We need to output a single epoch marker once all
+ * source tasks have sent one.
+ */
+private[shuffle] class RPCContinuousShuffleReader(
+ queueSize: Int,
+ numShuffleWriters: Int,
+ epochIntervalMs: Long,
+ override val rpcEnv: RpcEnv)
+ extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging {
+ // Note that this queue will be drained from the main task thread and populated in the RPC
+ // response thread.
+ private val queues = Array.fill(numShuffleWriters) {
+ new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize)
+ }
+
+ // Exposed for testing to determine if the endpoint gets stopped on task end.
+ private[shuffle] val stopped = new AtomicBoolean(false)
+
+ override def onStop(): Unit = {
+ stopped.set(true)
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case r: RPCContinuousShuffleMessage =>
+ // Note that this will block a thread in the shared RPC handler pool!
+ // The TCP based shuffle handler (SPARK-24541) will avoid this problem.
+ queues(r.writerId).put(r)
+ context.reply(())
+ }
+
+ override def read(): Iterator[UnsafeRow] = {
+ new NextIterator[UnsafeRow] {
+ // An array of flags for whether each writer ID has gotten an epoch marker.
+ private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false)
+
+ private val executor = Executors.newFixedThreadPool(numShuffleWriters)
+ private val completion = new ExecutorCompletionService[RPCContinuousShuffleMessage](executor)
+
+ private def completionTask(writerId: Int) = new Callable[RPCContinuousShuffleMessage] {
+ override def call(): RPCContinuousShuffleMessage = queues(writerId).take()
+ }
+
+ // Initialize by submitting tasks to read the first row from each writer.
+ (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId)))
+
+ /**
+ * In each call to getNext(), we pull the next row available in the completion queue, and then
+ * submit another task to read the next row from the writer which returned it.
+ *
+ * When a writer sends an epoch marker, we note that it's finished and don't submit another
+ * task for it in this epoch. The iterator is over once all writers have sent an epoch marker.
+ */
+ override def getNext(): UnsafeRow = {
+ var nextRow: UnsafeRow = null
+ while (!finished && nextRow == null) {
+ completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match {
+ case null =>
+ // Try again if the poll didn't wait long enough to get a real result.
+ // But we should be getting at least an epoch marker every checkpoint interval.
+ val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect {
+ case (flag, idx) if !flag => idx
+ }
+ logWarning(
+ s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " +
+ s"for writers $writerIdsUncommitted to send epoch markers.")
+
+ // The completion service guarantees this future will be available immediately.
+ case future => future.get() match {
+ case ReceiverRow(writerId, r) =>
+ // Start reading the next element in the queue we just took from.
+ completion.submit(completionTask(writerId))
+ nextRow = r
+ case ReceiverEpochMarker(writerId) =>
+ // Don't read any more from this queue. If all the writers have sent epoch markers,
+ // the epoch is over; otherwise we need to loop again to poll from the remaining
+ // writers.
+ writerEpochMarkersReceived(writerId) = true
+ if (writerEpochMarkersReceived.forall(_ == true)) {
+ finished = true
+ }
+ }
+ }
+ }
+
+ nextRow
+ }
+
+ override def close(): Unit = {
+ executor.shutdownNow()
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala
new file mode 100644
index 0000000000000..1c6f3ddb395e6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.Partitioner
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances.
+ *
+ * @param writerId The partition ID of this writer.
+ * @param outputPartitioner The partitioner on the reader side of the shuffle.
+ * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by
+ * partition ID within outputPartitioner.
+ */
+class RPCContinuousShuffleWriter(
+ writerId: Int,
+ outputPartitioner: Partitioner,
+ endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter {
+
+ if (outputPartitioner.numPartitions != 1) {
+ throw new IllegalArgumentException("multiple readers not yet supported")
+ }
+
+ if (outputPartitioner.numPartitions != endpoints.length) {
+ throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " +
+ s"not match endpoint count ${endpoints.length}")
+ }
+
+ def write(epoch: Iterator[UnsafeRow]): Unit = {
+ while (epoch.hasNext) {
+ val row = epoch.next()
+ endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row))
+ }
+
+ val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq
+ implicit val ec = ThreadUtils.sameThread
+ ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf)
+ }
+}
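Paired with a reader partition like the one sketched above, using the writer amounts to handing it the reader endpoints and an iterator of rows for one epoch. A sketch, assuming `readerEndpoint` is such an endpoint and `rows` an Iterator[UnsafeRow]; only a single reader partition is supported so far:

    // Sketch only; assumes `readerEndpoint: RpcEndpointRef` and `rows: Iterator[UnsafeRow]`.
    import org.apache.spark.HashPartitioner

    val shuffleWriter = new RPCContinuousShuffleWriter(
      writerId = 0,
      outputPartitioner = new HashPartitioner(1),
      endpoints = Array(readerEndpoint))
    shuffleWriter.write(rows)   // sends every row, then an epoch marker to each endpoint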
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 509a69dd922fb..7fa13c4aa2c01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -17,26 +17,29 @@
package org.apache.spark.sql.execution.streaming
+import java.{util => ju}
+import java.util.Optional
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
-import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.encoderFor
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
-
object MemoryStream {
protected val currentBlockId = new AtomicInteger(0)
protected val memoryStreamId = new AtomicInteger(0)
@@ -45,15 +48,43 @@ object MemoryStream {
new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
}
+/**
+ * A base class for memory stream implementations. Supports adding data and resetting.
+ */
+abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource {
+ protected val encoder = encoderFor[A]
+ protected val attributes = encoder.schema.toAttributes
+
+ def toDS(): Dataset[A] = {
+ Dataset[A](sqlContext.sparkSession, logicalPlan)
+ }
+
+ def toDF(): DataFrame = {
+ Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
+ }
+
+ def addData(data: A*): Offset = {
+ addData(data.toTraversable)
+ }
+
+ def readSchema(): StructType = encoder.schema
+
+ protected def logicalPlan: LogicalPlan
+
+ def addData(data: TraversableOnce[A]): Offset
+}
+
/**
* A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]]
* is intended for use in unit tests as it can only replay data when the object is still
* available.
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
- extends Source with Logging {
- protected val encoder = encoderFor[A]
- protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession)
+ extends MemoryStreamBase[A](sqlContext)
+ with MicroBatchReader with SupportsScanUnsafeRow with Logging {
+
+ protected val logicalPlan: LogicalPlan =
+ StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
protected val output = logicalPlan.output
/**
@@ -61,11 +92,17 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
* Stored in a ListBuffer to facilitate removing committed batches.
*/
@GuardedBy("this")
- protected val batches = new ListBuffer[Dataset[A]]
+ protected val batches = new ListBuffer[Array[UnsafeRow]]
@GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)
+ @GuardedBy("this")
+ protected var startOffset = new LongOffset(-1)
+
+ @GuardedBy("this")
+ private var endOffset = new LongOffset(-1)
+
/**
* Last offset that was discarded, or -1 if no commits have occurred. Note that the value
* -1 is used in calculations below and isn't just an arbitrary constant.
@@ -73,87 +110,68 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
@GuardedBy("this")
protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
- def schema: StructType = encoder.schema
-
- def toDS(): Dataset[A] = {
- Dataset(sqlContext.sparkSession, logicalPlan)
- }
-
- def toDF(): DataFrame = {
- Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
- }
-
- def addData(data: A*): Offset = {
- addData(data.toTraversable)
- }
-
def addData(data: TraversableOnce[A]): Offset = {
- val encoded = data.toVector.map(d => encoder.toRow(d).copy())
- val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true)
- val ds = Dataset[A](sqlContext.sparkSession, plan)
- logDebug(s"Adding ds: $ds")
+ val objects = data.toSeq
+ val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
+ logDebug(s"Adding: $objects")
this.synchronized {
currentOffset = currentOffset + 1
- batches += ds
+ batches += rows
currentOffset
}
}
override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"
- override def getOffset: Option[Offset] = synchronized {
- if (currentOffset.offset == -1) {
- None
- } else {
- Some(currentOffset)
+ override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
+ synchronized {
+ startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset]
+ endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset]
}
}
- override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
- // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
- val startOrdinal =
- start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
- val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
-
- // Internal buffer only holds the batches after lastCommittedOffset.
- val newBlocks = synchronized {
- val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
- val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
- assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
- batches.slice(sliceStart, sliceEnd)
- }
+ override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
- if (newBlocks.isEmpty) {
- return sqlContext.internalCreateDataFrame(
- sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
- }
+ override def getStartOffset: OffsetV2 = synchronized {
+ if (startOffset.offset == -1) null else startOffset
+ }
- logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))
+ override def getEndOffset: OffsetV2 = synchronized {
+ if (endOffset.offset == -1) null else endOffset
+ }
- newBlocks
- .map(_.toDF())
- .reduceOption(_ union _)
- .getOrElse {
- sys.error("No data selected!")
+ override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = {
+ synchronized {
+ // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
+ val startOrdinal = startOffset.offset.toInt + 1
+ val endOrdinal = endOffset.offset.toInt + 1
+
+ // Internal buffer only holds the batches after lastCommittedOffset.
+ val newBlocks = synchronized {
+ val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
+ val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
+ assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
+ batches.slice(sliceStart, sliceEnd)
}
+
+ logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))
+
+ newBlocks.map { block =>
+ new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]]
+ }.asJava
+ }
}
private def generateDebugString(
- blocks: TraversableOnce[Dataset[A]],
+ rows: Seq[UnsafeRow],
startOrdinal: Int,
endOrdinal: Int): String = {
- val originalUnsupportedCheck =
- sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck")
- try {
- sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false")
- s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
- s"${blocks.flatMap(_.collect()).mkString(", ")}"
- } finally {
- sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck)
- }
+ val fromRow = encoder.resolveAndBind().fromRow _
+ s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
+ s"${rows.map(row => fromRow(row)).mkString(", ")}"
}
- override def commit(end: Offset): Unit = synchronized {
+ override def commit(end: OffsetV2): Unit = synchronized {
def check(newOffset: LongOffset): Unit = {
val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
@@ -176,16 +194,88 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def reset(): Unit = synchronized {
batches.clear()
+ startOffset = LongOffset(-1)
+ endOffset = LongOffset(-1)
currentOffset = new LongOffset(-1)
lastOffsetCommitted = new LongOffset(-1)
}
}
+
+class MemoryStreamInputPartition(records: Array[UnsafeRow])
+ extends InputPartition[UnsafeRow] {
+ override def createPartitionReader(): InputPartitionReader[UnsafeRow] = {
+ new InputPartitionReader[UnsafeRow] {
+ private var currentIndex = -1
+
+ override def next(): Boolean = {
+ // Return true as long as the new index is in the array.
+ currentIndex += 1
+ currentIndex < records.length
+ }
+
+ override def get(): UnsafeRow = records(currentIndex)
+
+ override def close(): Unit = {}
+ }
+ }
+}
+
+/** A common trait for MemorySinks with methods used for testing */
+trait MemorySinkBase extends BaseStreamingSink with Logging {
+ def allData: Seq[Row]
+ def latestBatchData: Seq[Row]
+ def dataSinceBatch(sinceBatchId: Long): Seq[Row]
+ def latestBatchId: Option[Long]
+
+ /**
+ * Truncates the given rows to return at most maxRows rows.
+ * @param rows The data that may need to be truncated.
+ * @param batchLimit Number of rows to keep in this batch; the rest will be truncated
+ * @param sinkLimit Total number of rows kept in this sink, for logging purposes.
+ * @param batchId The ID of the batch that sent these rows, for logging purposes.
+ * @return Truncated rows.
+ */
+ protected def truncateRowsIfNeeded(
+ rows: Array[Row],
+ batchLimit: Int,
+ sinkLimit: Int,
+ batchId: Long): Array[Row] = {
+ if (rows.length > batchLimit && batchLimit >= 0) {
+ logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit")
+ rows.take(batchLimit)
+ } else {
+ rows
+ }
+ }
+}
+
+/**
+ * Companion object to MemorySinkBase.
+ */
+object MemorySinkBase {
+ val MAX_MEMORY_SINK_ROWS = "maxRows"
+ val MAX_MEMORY_SINK_ROWS_DEFAULT = -1
+
+ /**
+ * Gets the max number of rows a MemorySink should store. This number is based on the memory
+ * sink row limit option if it is set. If not, we use a large value so that data truncates
+ * rather than causing out of memory errors.
+ * @param options Writer options, from which the max rows setting is read.
+ * @return The maximum number of rows a MemorySink should store.
+ */
+ def getMemorySinkCapacity(options: DataSourceOptions): Int = {
+ val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT)
+ if (maxRows >= 0) maxRows else Int.MaxValue - 10
+ }
+}
+
/**
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
*/
-class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging {
+class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions)
+ extends Sink with MemorySinkBase with Logging {
private case class AddedData(batchId: Long, data: Array[Row])
@@ -193,9 +283,15 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
@GuardedBy("this")
private val batches = new ArrayBuffer[AddedData]()
+ /** The number of rows in this MemorySink. */
+ private var numRows = 0
+
+ /** The capacity in rows of this sink. */
+ val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
/** Returns all rows that are stored in this [[Sink]]. */
def allData: Seq[Row] = synchronized {
- batches.map(_.data).flatten
+ batches.flatMap(_.data)
}
def latestBatchId: Option[Long] = synchronized {
@@ -204,6 +300,10 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) }
+ def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized {
+ batches.filter(_.batchId > sinceBatchId).flatMap(_.data)
+ }
+
def toDebugString: String = synchronized {
batches.map { case AddedData(batchId, data) =>
val dataStr = try data.mkString(" ") catch {
@@ -221,14 +321,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
logDebug(s"Committing batch $batchId to $this")
outputMode match {
case Append | Update =>
- val rows = AddedData(batchId, data.collect())
- synchronized { batches += rows }
+ var rowsToAdd = data.collect()
+ synchronized {
+ rowsToAdd =
+ truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId)
+ val rows = AddedData(batchId, rowsToAdd)
+ batches += rows
+ numRows += rowsToAdd.length
+ }
case Complete =>
- val rows = AddedData(batchId, data.collect())
+ var rowsToAdd = data.collect()
synchronized {
+ rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId)
+ val rows = AddedData(batchId, rowsToAdd)
batches.clear()
batches += rows
+ numRows = rowsToAdd.length
}
case _ =>
@@ -242,6 +351,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
def clear(): Unit = synchronized {
batches.clear()
+ numRows = 0
}
override def toString(): String = "MemorySink"
@@ -253,7 +363,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)
- private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum
+ private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes)
override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
}
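The new `maxRows` option reaches the sink through DataSourceOptions, and MemorySinkBase.getMemorySinkCapacity turns it into a per-sink capacity. A sketch of constructing a capped sink directly, assuming `schema: StructType` and a batch DataFrame `df` are already in scope:

    // Sketch only; assumes `schema: StructType` and `df: DataFrame` exist.
    import scala.collection.JavaConverters._
    import org.apache.spark.sql.sources.v2.DataSourceOptions
    import org.apache.spark.sql.streaming.OutputMode

    val options = new DataSourceOptions(Map(MemorySinkBase.MAX_MEMORY_SINK_ROWS -> "10").asJava)
    val sink = new MemorySink(schema, OutputMode.Append(), options)
    sink.addBatch(0, df)             // rows beyond the 10-row capacity are truncated with a warning
    assert(sink.allData.size <= 10)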
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
deleted file mode 100644
index 0b22cbc46e6bf..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
+++ /dev/null
@@ -1,219 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.io.{BufferedReader, InputStreamReader, IOException}
-import java.net.Socket
-import java.sql.Timestamp
-import java.text.SimpleDateFormat
-import java.util.{Calendar, Locale}
-import javax.annotation.concurrent.GuardedBy
-
-import scala.collection.mutable.ListBuffer
-import scala.util.{Failure, Success, Try}
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
-import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
-import org.apache.spark.unsafe.types.UTF8String
-
-
-object TextSocketSource {
- val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil)
- val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) ::
- StructField("timestamp", TimestampType) :: Nil)
- val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
-}
-
-/**
- * A source that reads text lines through a TCP socket, designed only for tutorials and debugging.
- * This source will *not* work in production applications due to multiple reasons, including no
- * support for fault recovery and keeping all of the text read in memory forever.
- */
-class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext)
- extends Source with Logging {
-
- @GuardedBy("this")
- private var socket: Socket = null
-
- @GuardedBy("this")
- private var readThread: Thread = null
-
- /**
- * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive.
- * Stored in a ListBuffer to facilitate removing committed batches.
- */
- @GuardedBy("this")
- protected val batches = new ListBuffer[(String, Timestamp)]
-
- @GuardedBy("this")
- protected var currentOffset: LongOffset = new LongOffset(-1)
-
- @GuardedBy("this")
- protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
-
- initialize()
-
- private def initialize(): Unit = synchronized {
- socket = new Socket(host, port)
- val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
- readThread = new Thread(s"TextSocketSource($host, $port)") {
- setDaemon(true)
-
- override def run(): Unit = {
- try {
- while (true) {
- val line = reader.readLine()
- if (line == null) {
- // End of file reached
- logWarning(s"Stream closed by $host:$port")
- return
- }
- TextSocketSource.this.synchronized {
- val newData = (line,
- Timestamp.valueOf(
- TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime()))
- )
- currentOffset = currentOffset + 1
- batches.append(newData)
- }
- }
- } catch {
- case e: IOException =>
- }
- }
- }
- readThread.start()
- }
-
- /** Returns the schema of the data from this source */
- override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP
- else TextSocketSource.SCHEMA_REGULAR
-
- override def getOffset: Option[Offset] = synchronized {
- if (currentOffset.offset == -1) {
- None
- } else {
- Some(currentOffset)
- }
- }
-
- /** Returns the data that is between the offsets (`start`, `end`]. */
- override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
- val startOrdinal =
- start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
- val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
-
- // Internal buffer only holds the batches after lastOffsetCommitted
- val rawList = synchronized {
- val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
- val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
- batches.slice(sliceStart, sliceEnd)
- }
-
- val rdd = sqlContext.sparkContext
- .parallelize(rawList)
- .map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) }
- sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
- }
-
- override def commit(end: Offset): Unit = synchronized {
- val newOffset = LongOffset.convert(end).getOrElse(
- sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " +
- s"originate with an instance of this class")
- )
-
- val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
-
- if (offsetDiff < 0) {
- sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
- }
-
- batches.trimStart(offsetDiff)
- lastOffsetCommitted = newOffset
- }
-
- /** Stop this source. */
- override def stop(): Unit = synchronized {
- if (socket != null) {
- try {
- // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to
- // stop the readThread is to close the socket.
- socket.close()
- } catch {
- case e: IOException =>
- }
- socket = null
- }
- }
-
- override def toString: String = s"TextSocketSource[host: $host, port: $port]"
-}
-
-class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
- private def parseIncludeTimestamp(params: Map[String, String]): Boolean = {
- Try(params.getOrElse("includeTimestamp", "false").toBoolean) match {
- case Success(bool) => bool
- case Failure(_) =>
- throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
- }
- }
-
- /** Returns the name and schema of the source that can be used to continually read data. */
- override def sourceSchema(
- sqlContext: SQLContext,
- schema: Option[StructType],
- providerName: String,
- parameters: Map[String, String]): (String, StructType) = {
- logWarning("The socket source should not be used for production applications! " +
- "It does not support recovery.")
- if (!parameters.contains("host")) {
- throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
- }
- if (!parameters.contains("port")) {
- throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
- }
- if (schema.nonEmpty) {
- throw new AnalysisException("The socket source does not support a user-specified schema.")
- }
-
- val sourceSchema =
- if (parseIncludeTimestamp(parameters)) {
- TextSocketSource.SCHEMA_TIMESTAMP
- } else {
- TextSocketSource.SCHEMA_REGULAR
- }
- ("textSocket", sourceSchema)
- }
-
- override def createSource(
- sqlContext: SQLContext,
- metadataPath: String,
- schema: Option[StructType],
- providerName: String,
- parameters: Map[String, String]): Source = {
- val host = parameters("host")
- val port = parameters("port").toInt
- new TextSocketSource(host, port, parseIncludeTimestamp(parameters), sqlContext)
- }
-
- /** String that represents the format that this data source provider uses. */
- override def shortName(): String = "socket"
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
new file mode 100644
index 0000000000000..d1c3498450096
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.{util => ju}
+import java.util.Optional
+import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
+
+import org.json4s.NoTypeHints
+import org.json4s.jackson.Serialization
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.{Encoder, Row, SQLContext}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
+import org.apache.spark.sql.sources.v2.reader.InputPartition
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.RpcUtils
+
+/**
+ * The overall strategy here is:
+ * * ContinuousMemoryStream maintains a list of records for each partition. addData() will
+ * distribute records evenly-ish across partitions.
+ * * RecordEndpoint is set up as an endpoint for executor-side
+ * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at
+ * the specified offset within the list, or None if that offset doesn't yet have a record.
+ */
+class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2)
+ extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport {
+ private implicit val formats = Serialization.formats(NoTypeHints)
+
+ protected val logicalPlan =
+ StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession)
+
+ // ContinuousReader implementation
+
+ @GuardedBy("this")
+ private val records = Seq.fill(numPartitions)(new ListBuffer[A])
+
+ @GuardedBy("this")
+ private var startOffset: ContinuousMemoryStreamOffset = _
+
+ private val recordEndpoint = new RecordEndpoint()
+ @volatile private var endpointRef: RpcEndpointRef = _
+
+ def addData(data: TraversableOnce[A]): Offset = synchronized {
+ // Distribute data evenly among partition lists.
+ data.toSeq.zipWithIndex.map {
+ case (item, index) => records(index % numPartitions) += item
+ }
+
+ // The new target offset is the offset where all records in all partitions have been processed.
+ ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap)
+ }
+
+ override def setStartOffset(start: Optional[Offset]): Unit = synchronized {
+ // Inferred initial offset is position 0 in each partition.
+ startOffset = start.orElse {
+ ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap)
+ }.asInstanceOf[ContinuousMemoryStreamOffset]
+ }
+
+ override def getStartOffset: Offset = synchronized {
+ startOffset
+ }
+
+ override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = {
+ ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json))
+ }
+
+ override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = {
+ ContinuousMemoryStreamOffset(
+ offsets.map {
+ case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num)
+ }.toMap
+ )
+ }
+
+ override def planInputPartitions(): ju.List[InputPartition[Row]] = {
+ synchronized {
+ val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
+ endpointRef =
+ recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)
+
+ startOffset.partitionNums.map {
+ case (part, index) =>
+ new ContinuousMemoryStreamInputPartition(
+ endpointName, part, index): InputPartition[Row]
+ }.toList.asJava
+ }
+ }
+
+ override def stop(): Unit = {
+ if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef)
+ }
+
+ override def commit(end: Offset): Unit = {}
+
+ // ContinuousReadSupport implementation
+ // This is necessary because of how StreamTest finds the source for AddDataMemory steps.
+ def createContinuousReader(
+ schema: Optional[StructType],
+ checkpointLocation: String,
+ options: DataSourceOptions): ContinuousReader = {
+ this
+ }
+
+ /**
+ * Endpoint for executors to poll for records.
+ */
+ private class RecordEndpoint extends ThreadSafeRpcEndpoint {
+ override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) =>
+ ContinuousMemoryStream.this.synchronized {
+ val buf = records(part)
+ val record = if (buf.size <= index) None else Some(buf(index))
+
+ context.reply(record.map(Row(_)))
+ }
+ }
+ }
+}
+
+object ContinuousMemoryStream {
+ case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset)
+ protected val memoryStreamId = new AtomicInteger(0)
+
+ def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+ new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+
+ def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+ new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1)
+}
+
+/**
+ * An input partition for continuous memory stream.
+ */
+class ContinuousMemoryStreamInputPartition(
+ driverEndpointName: String,
+ partition: Int,
+ startOffset: Int) extends InputPartition[Row] {
+ override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader =
+ new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset)
+}
+
+/**
+ * An input partition reader for continuous memory stream.
+ *
+ * Polls the driver endpoint for new records.
+ */
+class ContinuousMemoryStreamInputPartitionReader(
+ driverEndpointName: String,
+ partition: Int,
+ startOffset: Int) extends ContinuousInputPartitionReader[Row] {
+ private val endpoint = RpcUtils.makeDriverRef(
+ driverEndpointName,
+ SparkEnv.get.conf,
+ SparkEnv.get.rpcEnv)
+
+ private var currentOffset = startOffset
+ private var current: Option[Row] = None
+
+ override def next(): Boolean = {
+ current = getRecord
+ while (current.isEmpty) {
+ Thread.sleep(10)
+ current = getRecord
+ }
+ currentOffset += 1
+ true
+ }
+
+ override def get(): Row = current.get
+
+ override def close(): Unit = {}
+
+ override def getOffset: ContinuousMemoryStreamPartitionOffset =
+ ContinuousMemoryStreamPartitionOffset(partition, currentOffset)
+
+ private def getRecord: Option[Row] =
+ endpoint.askSync[Option[Row]](
+ GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset)))
+}
+
+case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int])
+ extends Offset {
+ private implicit val formats = Serialization.formats(NoTypeHints)
+ override def json(): String = Serialization.write(partitionNums)
+}
+
+case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int)
+ extends PartitionOffset
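
The offsets above are nothing more than per-partition record counts: each executor reports a ContinuousMemoryStreamPartitionOffset, and mergeOffsets folds them back into one map. A minimal sketch of that bookkeeping, assuming only the two case classes defined in this file are in scope:

    // Executors report how many records they have processed in their partition.
    val partitionOffsets = Seq(
      ContinuousMemoryStreamPartitionOffset(partition = 0, numProcessed = 3),
      ContinuousMemoryStreamPartitionOffset(partition = 1, numProcessed = 5))

    // mergeOffsets above does exactly this: collapse the reports into a single partition -> count map.
    val merged = ContinuousMemoryStreamOffset(
      partitionOffsets.map(o => o.partition -> o.numProcessed).toMap)

    // merged.partitionNums == Map(0 -> 3, 1 -> 5); json() serializes this map for the offset log.
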
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
new file mode 100644
index 0000000000000..f677f25f116a2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.PythonForeachWriter
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified
+ * [[ForeachWriter]].
+ *
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @param converter An object to convert internal rows to the target type T. It can be either
+ *                  an [[ExpressionEncoder]] or a direct converter function.
+ * @tparam T The expected type of the sink.
+ */
+case class ForeachWriterProvider[T](
+ writer: ForeachWriter[T],
+ converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport {
+
+ override def createStreamWriter(
+ queryId: String,
+ schema: StructType,
+ mode: OutputMode,
+ options: DataSourceOptions): StreamWriter = {
+ new StreamWriter with SupportsWriteInternalRow {
+ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+ override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+ override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
+ val rowConverter: InternalRow => T = converter match {
+ case Left(enc) =>
+ val boundEnc = enc.resolveAndBind(
+ schema.toAttributes,
+ SparkSession.getActiveSession.get.sessionState.analyzer)
+ boundEnc.fromRow
+ case Right(func) =>
+ func
+ }
+ ForeachWriterFactory(writer, rowConverter)
+ }
+
+ override def toString: String = "ForeachSink"
+ }
+ }
+}
+
+object ForeachWriterProvider {
+ def apply[T](
+ writer: ForeachWriter[T],
+ encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = {
+ writer match {
+ case pythonWriter: PythonForeachWriter =>
+ new ForeachWriterProvider[UnsafeRow](
+ pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow]))
+ case _ =>
+ new ForeachWriterProvider[T](writer, Left(encoder))
+ }
+ }
+}
+
+case class ForeachWriterFactory[T](
+ writer: ForeachWriter[T],
+ rowConverter: InternalRow => T)
+ extends DataWriterFactory[InternalRow] {
+ override def createDataWriter(
+ partitionId: Int,
+ attemptNumber: Int,
+ epochId: Long): ForeachDataWriter[T] = {
+ new ForeachDataWriter(writer, rowConverter, partitionId, epochId)
+ }
+}
+
+/**
+ * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]].
+ *
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]].
+ * @param partitionId The id of the partition this writer handles.
+ * @param epochId The epoch or batch id being written.
+ * @tparam T The type expected by the writer.
+ */
+class ForeachDataWriter[T](
+ writer: ForeachWriter[T],
+ rowConverter: InternalRow => T,
+ partitionId: Int,
+ epochId: Long)
+ extends DataWriter[InternalRow] {
+
+ // If open returns false, we should skip writing rows.
+ private val opened = writer.open(partitionId, epochId)
+
+ override def write(record: InternalRow): Unit = {
+ if (!opened) return
+
+ try {
+ writer.process(rowConverter(record))
+ } catch {
+ case t: Throwable =>
+ writer.close(t)
+ throw t
+ }
+ }
+
+ override def commit(): WriterCommitMessage = {
+ writer.close(null)
+ ForeachWriterCommitMessage
+ }
+
+ override def abort(): Unit = {}
+}
+
+/**
+ * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination.
+ */
+case object ForeachWriterCommitMessage extends WriterCommitMessage
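
ForeachDataWriter drives the standard ForeachWriter lifecycle: open() once per partition and epoch (returning false skips the partition), process() once per row, and close() with the failure or with null on successful commit. A sketch of a user-side writer exercising that contract; the class name and println body are illustrative only:

    import org.apache.spark.sql.ForeachWriter

    class ConsoleForeachWriter extends ForeachWriter[String] {
      // Returning false makes ForeachDataWriter.write() above a no-op for this partition.
      override def open(partitionId: Long, epochId: Long): Boolean = true

      // Called once per row with the value produced by the row converter.
      override def process(value: String): Unit = println(s"row: $value")

      // Receives the failure that aborted the partition, or null when commit() closed the writer.
      override def close(errorOrNull: Throwable): Unit =
        if (errorOrNull != null) errorOrNull.printStackTrace()
    }

Such a writer is attached with DataStreamWriter.foreach(new ConsoleForeachWriter), which is what routes it through the ForeachWriterProvider defined above.
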
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
index 248295e401a0d..e07355aa37dba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
@@ -31,7 +31,10 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat
* for production-quality sinks. It's intended for use in tests.
*/
case object PackedRowWriterFactory extends DataWriterFactory[Row] {
- def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
+ override def createDataWriter(
+ partitionId: Int,
+ attemptNumber: Int,
+ epochId: Long): DataWriter[Row] = {
new PackedRowDataWriter()
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
new file mode 100644
index 0000000000000..b393c48baee8d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.io._
+import java.nio.charset.StandardCharsets
+import java.util.Optional
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.IOUtils
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.{ManualClock, SystemClock}
+
+class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String)
+ extends MicroBatchReader with Logging {
+ import RateStreamProvider._
+
+ private[sources] val clock = {
+ // The option to use a manual clock is provided only for unit testing purposes.
+ if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock
+ }
+
+ private val rowsPerSecond =
+ options.get(ROWS_PER_SECOND).orElse("1").toLong
+
+ private val rampUpTimeSeconds =
+ Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String]))
+ .map(JavaUtils.timeStringAsSec(_))
+ .getOrElse(0L)
+
+ private val maxSeconds = Long.MaxValue / rowsPerSecond
+
+ if (rampUpTimeSeconds > maxSeconds) {
+ throw new ArithmeticException(
+ s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
+ s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
+ }
+
+ private[sources] val creationTimeMs = {
+ val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession)
+ require(session.isDefined)
+
+ val metadataLog =
+ new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) {
+ override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
+ val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
+ writer.write("v" + VERSION + "\n")
+ writer.write(metadata.json)
+ writer.flush
+ }
+
+ override def deserialize(in: InputStream): LongOffset = {
+ val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
+ // HDFSMetadataLog guarantees that it never creates a partial file.
+ assert(content.length != 0)
+ if (content(0) == 'v') {
+ val indexOfNewLine = content.indexOf("\n")
+ if (indexOfNewLine > 0) {
+ parseVersion(content.substring(0, indexOfNewLine), VERSION)
+ LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
+ } else {
+ throw new IllegalStateException(
+ s"Log file was malformed: failed to detect the log file version line.")
+ }
+ } else {
+ throw new IllegalStateException(
+ s"Log file was malformed: failed to detect the log file version line.")
+ }
+ }
+ }
+
+ metadataLog.get(0).getOrElse {
+ val offset = LongOffset(clock.getTimeMillis())
+ metadataLog.add(0, offset)
+ logInfo(s"Start time: $offset")
+ offset
+ }.offset
+ }
+
+ @volatile private var lastTimeMs: Long = creationTimeMs
+
+ private var start: LongOffset = _
+ private var end: LongOffset = _
+
+ override def readSchema(): StructType = SCHEMA
+
+ override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {
+ this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset]
+ this.end = end.orElse {
+ val now = clock.getTimeMillis()
+ if (lastTimeMs < now) {
+ lastTimeMs = now
+ }
+ LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs))
+ }.asInstanceOf[LongOffset]
+ }
+
+ override def getStartOffset(): Offset = {
+ if (start == null) throw new IllegalStateException("start offset not set")
+ start
+ }
+ override def getEndOffset(): Offset = {
+ if (end == null) throw new IllegalStateException("end offset not set")
+ end
+ }
+
+ override def deserializeOffset(json: String): Offset = {
+ LongOffset(json.toLong)
+ }
+
+ override def planInputPartitions(): java.util.List[InputPartition[Row]] = {
+ val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L)
+ val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
+ assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
+ if (endSeconds > maxSeconds) {
+ throw new ArithmeticException("Integer overflow. Max offset with " +
+ s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
+ }
+ // Fix "lastTimeMs" for recovery
+ if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) {
+ lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs
+ }
+ val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
+ val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
+ logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
+ s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
+
+ if (rangeStart == rangeEnd) {
+ return List.empty.asJava
+ }
+
+ val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
+ val relativeMsPerValue =
+ TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
+ val numPartitions = {
+ val activeSession = SparkSession.getActiveSession
+ require(activeSession.isDefined)
+ Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String]))
+ .map(_.toInt)
+ .getOrElse(activeSession.get.sparkContext.defaultParallelism)
+ }
+
+ (0 until numPartitions).map { p =>
+ new RateStreamMicroBatchInputPartition(
+ p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
+ : InputPartition[Row]
+ }.toList.asJava
+ }
+
+ override def commit(end: Offset): Unit = {}
+
+ override def stop(): Unit = {}
+
+ override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " +
+ s"rampUpTimeSeconds=$rampUpTimeSeconds, " +
+ s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}"
+}
+
+class RateStreamMicroBatchInputPartition(
+ partitionId: Int,
+ numPartitions: Int,
+ rangeStart: Long,
+ rangeEnd: Long,
+ localStartTimeMs: Long,
+ relativeMsPerValue: Double) extends InputPartition[Row] {
+
+ override def createPartitionReader(): InputPartitionReader[Row] =
+ new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd,
+ localStartTimeMs, relativeMsPerValue)
+}
+
+class RateStreamMicroBatchInputPartitionReader(
+ partitionId: Int,
+ numPartitions: Int,
+ rangeStart: Long,
+ rangeEnd: Long,
+ localStartTimeMs: Long,
+ relativeMsPerValue: Double) extends InputPartitionReader[Row] {
+ private var count: Long = 0
+
+ override def next(): Boolean = {
+ rangeStart + partitionId + numPartitions * count < rangeEnd
+ }
+
+ override def get(): Row = {
+ val currValue = rangeStart + partitionId + numPartitions * count
+ count += 1
+ val relative = math.round((currValue - rangeStart) * relativeMsPerValue)
+ Row(
+ DateTimeUtils.toJavaTimestamp(
+ DateTimeUtils.fromMillis(relative + localStartTimeMs)),
+ currValue
+ )
+ }
+
+ override def close(): Unit = {}
+}
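
The partition readers above stripe the value range across partitions: partition p emits rangeStart + p, rangeStart + p + numPartitions, and so on, with each value's timestamp offset by relativeMsPerValue. A small sketch of that striping with assumed example figures (2 partitions over the range [0, 6)):

    val numPartitions = 2
    val rangeStart = 0L
    val rangeEnd = 6L

    // Reproduces the "rangeStart + partitionId + numPartitions * count" walk from next()/get() above.
    val valuesByPartition = (0 until numPartitions).map { partitionId =>
      partitionId -> Iterator.from(0)
        .map(count => rangeStart + partitionId + numPartitions * count)
        .takeWhile(_ < rangeEnd)
        .toList
    }

    // valuesByPartition == Vector(0 -> List(0, 2, 4), 1 -> List(1, 3, 5)):
    // together the partitions emit every value in the range exactly once.
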
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
new file mode 100644
index 0000000000000..6bdd492f0cb35
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.util.Optional
+
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader}
+import org.apache.spark.sql.types._
+
+/**
+ * A source that generates incrementing long values with timestamps. Each generated row has two
+ * columns: a timestamp column for the generated time and an auto increment long column starting
+ * with 0L.
+ *
+ * This source supports the following options:
+ * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second.
+ * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
+ * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer
+ * seconds.
+ * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
+ * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
+ * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
+ */
+class RateStreamProvider extends DataSourceV2
+ with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister {
+ import RateStreamProvider._
+
+ override def createMicroBatchReader(
+ schema: Optional[StructType],
+ checkpointLocation: String,
+ options: DataSourceOptions): MicroBatchReader = {
+ if (options.get(ROWS_PER_SECOND).isPresent) {
+ val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong
+ if (rowsPerSecond <= 0) {
+ throw new IllegalArgumentException(
+ s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive")
+ }
+ }
+
+ if (options.get(RAMP_UP_TIME).isPresent) {
+ val rampUpTimeSeconds =
+ JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get())
+ if (rampUpTimeSeconds < 0) {
+ throw new IllegalArgumentException(
+ s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative")
+ }
+ }
+
+ if (options.get(NUM_PARTITIONS).isPresent) {
+ val numPartitions = options.get(NUM_PARTITIONS).get().toInt
+ if (numPartitions <= 0) {
+ throw new IllegalArgumentException(
+ s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive")
+ }
+ }
+
+ if (schema.isPresent) {
+ throw new AnalysisException("The rate source does not support a user-specified schema.")
+ }
+
+ new RateStreamMicroBatchReader(options, checkpointLocation)
+ }
+
+ override def createContinuousReader(
+ schema: Optional[StructType],
+ checkpointLocation: String,
+ options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options)
+
+ override def shortName(): String = "rate"
+}
+
+object RateStreamProvider {
+ val SCHEMA =
+ StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)
+
+ val VERSION = 1
+
+ val NUM_PARTITIONS = "numPartitions"
+ val ROWS_PER_SECOND = "rowsPerSecond"
+ val RAMP_UP_TIME = "rampUpTime"
+
+ /** Calculate the end value we will emit at the time `seconds`. */
+ def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
+ // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
+ // Then speedDeltaPerSecond = 2
+ //
+ // seconds = 0 1 2 3 4 5 6
+ // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds)
+ // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2
+ val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
+ if (seconds <= rampUpTimeSeconds) {
+ // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to
+ // avoid overflow
+ if (seconds % 2 == 1) {
+ (seconds + 1) / 2 * speedDeltaPerSecond * seconds
+ } else {
+ seconds / 2 * speedDeltaPerSecond * (seconds + 1)
+ }
+ } else {
+ // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds
+ val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
+ rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
+ }
+ }
+}
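
The ramp-up formula in valueAtSecond is the triangular sum of a linearly increasing speed, switching to a constant rowsPerSecond once the ramp-up window has passed. A sketch that reproduces the figures in the comment above (rowsPerSecond = 10, rampUpTimeSeconds = 4), ignoring the overflow guard the real method carries:

    val rowsPerSecond = 10L
    val rampUpTimeSeconds = 4L
    val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) // 2

    def valueAt(seconds: Long): Long =
      if (seconds <= rampUpTimeSeconds) {
        speedDeltaPerSecond * seconds * (seconds + 1) / 2                          // ramping up
      } else {
        valueAt(rampUpTimeSeconds) + (seconds - rampUpTimeSeconds) * rowsPerSecond // steady state
      }

    (0L to 6L).map(valueAt) // Vector(0, 2, 6, 12, 20, 30, 40), matching the table in the comment

From the user side the same options surface through the DataStreamReader, e.g. spark.readStream.format("rate").option("rowsPerSecond", "10").option("rampUpTime", "4s").load().
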
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
deleted file mode 100644
index 1315885da8a6f..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.sources
-
-import java.util.Optional
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
-import org.json4s.DefaultFormats
-import org.json4s.jackson.Serialization
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
-import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2}
-import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
-import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
-import org.apache.spark.util.{ManualClock, SystemClock}
-
-/**
- * This is a temporary register as we build out v2 migration. Microbatch read support should
- * be implemented in the same register as v1.
- */
-class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister {
- override def createMicroBatchReader(
- schema: Optional[StructType],
- checkpointLocation: String,
- options: DataSourceOptions): MicroBatchReader = {
- new RateStreamMicroBatchReader(options)
- }
-
- override def shortName(): String = "ratev2"
-}
-
-class RateStreamMicroBatchReader(options: DataSourceOptions)
- extends MicroBatchReader {
- implicit val defaultFormats: DefaultFormats = DefaultFormats
-
- val clock = {
- // The option to use a manual clock is provided only for unit testing purposes.
- if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock
- else new SystemClock
- }
-
- private val numPartitions =
- options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
- private val rowsPerSecond =
- options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
-
- // The interval (in milliseconds) between rows in each partition.
- // e.g. if there are 4 global rows per second, and 2 partitions, each partition
- // should output rows every (1000 * 2 / 4) = 500 ms.
- private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond
-
- override def readSchema(): StructType = {
- StructType(
- StructField("timestamp", TimestampType, false) ::
- StructField("value", LongType, false) :: Nil)
- }
-
- val creationTimeMs = clock.getTimeMillis()
-
- private var start: RateStreamOffset = _
- private var end: RateStreamOffset = _
-
- override def setOffsetRange(
- start: Optional[Offset],
- end: Optional[Offset]): Unit = {
- this.start = start.orElse(
- RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs))
- .asInstanceOf[RateStreamOffset]
-
- this.end = end.orElse {
- val currentTime = clock.getTimeMillis()
- RateStreamOffset(
- this.start.partitionToValueAndRunTimeMs.map {
- case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
- // Calculate the number of rows we should advance in this partition (based on the
- // current time), and output a corresponding offset.
- val readInterval = currentTime - currentReadTime
- val numNewRows = readInterval / msPerPartitionBetweenRows
- if (numNewRows <= 0) {
- startOffset
- } else {
- (part, ValueRunTimeMsPair(
- currentVal + (numNewRows * numPartitions),
- currentReadTime + (numNewRows * msPerPartitionBetweenRows)))
- }
- }
- )
- }.asInstanceOf[RateStreamOffset]
- }
-
- override def getStartOffset(): Offset = {
- if (start == null) throw new IllegalStateException("start offset not set")
- start
- }
- override def getEndOffset(): Offset = {
- if (end == null) throw new IllegalStateException("end offset not set")
- end
- }
-
- override def deserializeOffset(json: String): Offset = {
- RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
- }
-
- override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
- val startMap = start.partitionToValueAndRunTimeMs
- val endMap = end.partitionToValueAndRunTimeMs
- endMap.keys.toSeq.map { part =>
- val ValueRunTimeMsPair(endVal, _) = endMap(part)
- val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part)
-
- val packedRows = mutable.ListBuffer[(Long, Long)]()
- var outVal = startVal + numPartitions
- var outTimeMs = startTimeMs
- while (outVal <= endVal) {
- packedRows.append((outTimeMs, outVal))
- outVal += numPartitions
- outTimeMs += msPerPartitionBetweenRows
- }
-
- RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]]
- }.toList.asJava
- }
-
- override def commit(end: Offset): Unit = {}
- override def stop(): Unit = {}
-}
-
-case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] {
- override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals)
-}
-
-class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
- var currentIndex = -1
-
- override def next(): Boolean = {
- // Return true as long as the new index is in the seq.
- currentIndex += 1
- currentIndex < vals.size
- }
-
- override def get(): Row = {
- Row(
- DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)),
- vals(currentIndex)._2)
- }
-
- override def close(): Unit = {}
-}
-
-object RateStreamSourceV2 {
- val NUM_PARTITIONS = "numPartitions"
- val ROWS_PER_SECOND = "rowsPerSecond"
-
- private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
- RateStreamOffset(
- Range(0, numPartitions).map { i =>
- // Note that the starting offset is exclusive, so we have to decrement the starting value
- // by the increment that will later be applied. The first row output in each
- // partition will have a value equal to the partition index.
- (i,
- ValueRunTimeMsPair(
- (i - numPartitions).toLong,
- creationTimeMs))
- }.toMap)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index 3411edbc53412..47b482007822d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -27,9 +27,10 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
-import org.apache.spark.sql.execution.streaming.Sink
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2}
+import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
@@ -39,13 +40,13 @@ import org.apache.spark.sql.types.StructType
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
*/
-class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
+class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging {
override def createStreamWriter(
queryId: String,
schema: StructType,
mode: OutputMode,
options: DataSourceOptions): StreamWriter = {
- new MemoryStreamWriter(this, mode)
+ new MemoryStreamWriter(this, mode, options)
}
private case class AddedData(batchId: Long, data: Array[Row])
@@ -54,6 +55,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
@GuardedBy("this")
private val batches = new ArrayBuffer[AddedData]()
+ /** The number of rows in this MemorySink. */
+ private var numRows = 0
+
/** Returns all rows that are stored in this [[Sink]]. */
def allData: Seq[Row] = synchronized {
batches.flatMap(_.data)
@@ -67,6 +71,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
batches.lastOption.toSeq.flatten(_.data)
}
+ def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized {
+ batches.filter(_.batchId > sinceBatchId).flatMap(_.data)
+ }
+
def toDebugString: String = synchronized {
batches.map { case AddedData(batchId, data) =>
val dataStr = try data.mkString(" ") catch {
@@ -76,7 +84,11 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
}.mkString("\n")
}
- def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = {
+ def write(
+ batchId: Long,
+ outputMode: OutputMode,
+ newRows: Array[Row],
+ sinkCapacity: Int): Unit = {
val notCommitted = synchronized {
latestBatchId.isEmpty || batchId > latestBatchId.get
}
@@ -84,19 +96,26 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
logDebug(s"Committing batch $batchId to $this")
outputMode match {
case Append | Update =>
- val rows = AddedData(batchId, newRows)
- synchronized { batches += rows }
+ synchronized {
+ val rowsToAdd =
+ truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId)
+ val rows = AddedData(batchId, rowsToAdd)
+ batches += rows
+ numRows += rowsToAdd.length
+ }
case Complete =>
- val rows = AddedData(batchId, newRows)
synchronized {
+ val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId)
+ val rows = AddedData(batchId, rowsToAdd)
batches.clear()
batches += rows
+ numRows = rowsToAdd.length
}
case _ =>
throw new IllegalArgumentException(
- s"Output mode $outputMode is not supported by MemorySink")
+ s"Output mode $outputMode is not supported by MemorySinkV2")
}
} else {
logDebug(s"Skipping already committed batch: $batchId")
@@ -105,23 +124,30 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging {
def clear(): Unit = synchronized {
batches.clear()
+ numRows = 0
}
- override def toString(): String = "MemorySink"
+ override def toString(): String = "MemorySinkV2"
}
case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {}
-class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
+class MemoryWriter(
+ sink: MemorySinkV2,
+ batchId: Long,
+ outputMode: OutputMode,
+ options: DataSourceOptions)
extends DataSourceWriter with Logging {
+ val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
def commit(messages: Array[WriterCommitMessage]): Unit = {
val newRows = messages.flatMap {
case message: MemoryWriterCommitMessage => message.data
}
- sink.write(batchId, outputMode, newRows)
+ sink.write(batchId, outputMode, newRows, sinkCapacity)
}
override def abort(messages: Array[WriterCommitMessage]): Unit = {
@@ -129,16 +155,21 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
}
}
-class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
+class MemoryStreamWriter(
+ val sink: MemorySinkV2,
+ outputMode: OutputMode,
+ options: DataSourceOptions)
extends StreamWriter {
+ val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
+
override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
val newRows = messages.flatMap {
case message: MemoryWriterCommitMessage => message.data
}
- sink.write(epochId, outputMode, newRows)
+ sink.write(epochId, outputMode, newRows, sinkCapacity)
}
override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
@@ -147,7 +178,10 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
}
case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] {
- def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
+ override def createDataWriter(
+ partitionId: Int,
+ attemptNumber: Int,
+ epochId: Long): DataWriter[Row] = {
new MemoryDataWriter(partitionId, outputMode)
}
}
@@ -172,10 +206,10 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode)
/**
- * Used to query the data that has been written into a [[MemorySink]].
+ * Used to query the data that has been written into a [[MemorySinkV2]].
*/
case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
- private val sizePerRow = output.map(_.dataType.defaultSize).sum
+ private val sizePerRow = EstimationUtils.getSizePerRow(output)
override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
}
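
From a query's point of view the V2 memory sink still behaves like the familiar memory sink: results accumulate in an in-memory table (or are replaced wholesale in Complete mode), except that rows beyond the sink's capacity are now cut off by truncateRowsIfNeeded (defined in MemorySinkBase, which is not part of this diff). A usage sketch, assuming a SparkSession named spark and a streaming DataFrame named events:

    // Write the stream into an in-memory table for interactive inspection (tests/debugging only).
    val query = events.writeStream
      .format("memory")
      .queryName("events_table")
      .outputMode("append")
      .start()

    // The accumulated rows are visible as a temporary view.
    spark.sql("SELECT * FROM events_table").show()
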
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
new file mode 100644
index 0000000000000..91e3b7179c34a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.io.{BufferedReader, InputStreamReader, IOException}
+import java.net.Socket
+import java.sql.Timestamp
+import java.text.SimpleDateFormat
+import java.util.{Calendar, List => JList, Locale, Optional}
+import java.util.concurrent.atomic.AtomicBoolean
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
+import scala.util.{Failure, Success, Try}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.streaming.LongOffset
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
+import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+
+object TextSocketMicroBatchReader {
+ val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil)
+ val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) ::
+ StructField("timestamp", TimestampType) :: Nil)
+ val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
+}
+
+/**
+ * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and
+ * debugging. This MicroBatchReader will *not* work in production applications due to multiple
+ * reasons, including no support for fault recovery.
+ */
+class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging {
+
+ private var startOffset: Offset = _
+ private var endOffset: Offset = _
+
+ private val host: String = options.get("host").get()
+ private val port: Int = options.get("port").get().toInt
+
+ @GuardedBy("this")
+ private var socket: Socket = null
+
+ @GuardedBy("this")
+ private var readThread: Thread = null
+
+ /**
+ * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive.
+ * Stored in a ListBuffer to facilitate removing committed batches.
+ */
+ @GuardedBy("this")
+ private val batches = new ListBuffer[(String, Timestamp)]
+
+ @GuardedBy("this")
+ private var currentOffset: LongOffset = LongOffset(-1L)
+
+ @GuardedBy("this")
+ private var lastOffsetCommitted: LongOffset = LongOffset(-1L)
+
+ private val initialized: AtomicBoolean = new AtomicBoolean(false)
+
+  /** This method is only used for unit tests. */
+ private[sources] def getCurrentOffset(): LongOffset = synchronized {
+ currentOffset.copy()
+ }
+
+ private def initialize(): Unit = synchronized {
+ socket = new Socket(host, port)
+ val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
+ readThread = new Thread(s"TextSocketSource($host, $port)") {
+ setDaemon(true)
+
+ override def run(): Unit = {
+ try {
+ while (true) {
+ val line = reader.readLine()
+ if (line == null) {
+ // End of file reached
+ logWarning(s"Stream closed by $host:$port")
+ return
+ }
+ TextSocketMicroBatchReader.this.synchronized {
+ val newData = (line,
+ Timestamp.valueOf(
+ TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime()))
+ )
+ currentOffset += 1
+ batches.append(newData)
+ }
+ }
+ } catch {
+ case e: IOException =>
+ }
+ }
+ }
+ readThread.start()
+ }
+
+ override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized {
+ startOffset = start.orElse(LongOffset(-1L))
+ endOffset = end.orElse(currentOffset)
+ }
+
+ override def getStartOffset(): Offset = {
+ Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set"))
+ }
+
+ override def getEndOffset(): Offset = {
+ Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set"))
+ }
+
+ override def deserializeOffset(json: String): Offset = {
+ LongOffset(json.toLong)
+ }
+
+ override def readSchema(): StructType = {
+ if (options.getBoolean("includeTimestamp", false)) {
+ TextSocketMicroBatchReader.SCHEMA_TIMESTAMP
+ } else {
+ TextSocketMicroBatchReader.SCHEMA_REGULAR
+ }
+ }
+
+ override def planInputPartitions(): JList[InputPartition[Row]] = {
+ assert(startOffset != null && endOffset != null,
+ "start offset and end offset should already be set before create read tasks.")
+
+ val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1
+ val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1
+
+ // Internal buffer only holds the batches after lastOffsetCommitted
+ val rawList = synchronized {
+ if (initialized.compareAndSet(false, true)) {
+ initialize()
+ }
+
+ val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
+ val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
+ batches.slice(sliceStart, sliceEnd)
+ }
+
+ assert(SparkSession.getActiveSession.isDefined)
+ val spark = SparkSession.getActiveSession.get
+ val numPartitions = spark.sparkContext.defaultParallelism
+
+ val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)])
+ rawList.zipWithIndex.foreach { case (r, idx) =>
+ slices(idx % numPartitions).append(r)
+ }
+
+ (0 until numPartitions).map { i =>
+ val slice = slices(i)
+ new InputPartition[Row] {
+ override def createPartitionReader(): InputPartitionReader[Row] =
+ new InputPartitionReader[Row] {
+ private var currentIdx = -1
+
+ override def next(): Boolean = {
+ currentIdx += 1
+ currentIdx < slice.size
+ }
+
+ override def get(): Row = {
+ Row(slice(currentIdx)._1, slice(currentIdx)._2)
+ }
+
+ override def close(): Unit = {}
+ }
+ }
+ }.toList.asJava
+ }
+
+ override def commit(end: Offset): Unit = synchronized {
+ val newOffset = LongOffset.convert(end).getOrElse(
+ sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " +
+ s"originate with an instance of this class")
+ )
+
+ val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
+
+ if (offsetDiff < 0) {
+ sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
+ }
+
+ batches.trimStart(offsetDiff)
+ lastOffsetCommitted = newOffset
+ }
+
+ /** Stop this source. */
+ override def stop(): Unit = synchronized {
+ if (socket != null) {
+ try {
+ // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to
+ // stop the readThread is to close the socket.
+ socket.close()
+ } catch {
+ case e: IOException =>
+ }
+ socket = null
+ }
+ }
+
+ override def toString: String = s"TextSocketV2[host: $host, port: $port]"
+}
+
+class TextSocketSourceProvider extends DataSourceV2
+ with MicroBatchReadSupport with DataSourceRegister with Logging {
+
+ private def checkParameters(params: DataSourceOptions): Unit = {
+ logWarning("The socket source should not be used for production applications! " +
+ "It does not support recovery.")
+ if (!params.get("host").isPresent) {
+ throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
+ }
+ if (!params.get("port").isPresent) {
+ throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
+ }
+ Try {
+ params.get("includeTimestamp").orElse("false").toBoolean
+ } match {
+ case Success(_) =>
+ case Failure(_) =>
+ throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
+ }
+ }
+
+ override def createMicroBatchReader(
+ schema: Optional[StructType],
+ checkpointLocation: String,
+ options: DataSourceOptions): MicroBatchReader = {
+ checkParameters(options)
+ if (schema.isPresent) {
+ throw new AnalysisException("The socket source does not support a user-specified schema.")
+ }
+
+ new TextSocketMicroBatchReader(options)
+ }
+
+ /** String that represents the format that this data source provider uses. */
+ override def shortName(): String = "socket"
+}
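
As the provider's warning says, the socket source is meant for tutorials and debugging only. A usage sketch, assuming a SparkSession named spark and a line server such as `nc -lk 9999` on localhost:

    val lines = spark.readStream
      .format("socket")
      .option("host", "localhost")
      .option("port", "9999")
      .option("includeTimestamp", "true")
      .load()

    // With includeTimestamp=true the schema is (value STRING, timestamp TIMESTAMP),
    // per readSchema() above; otherwise it is just (value STRING).
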
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 3f5002a4e6937..118c82aa75e68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -17,23 +17,24 @@
package org.apache.spark.sql.execution.streaming.state
-import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException}
-import java.nio.channels.ClosedChannelException
+import java.io._
import java.util.Locale
import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.util.Random
import scala.util.control.NonFatal
import com.google.common.io.ByteStreams
+import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs._
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.io.LZ4CompressionCodec
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{SizeEstimator, Utils}
@@ -87,10 +88,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
case object ABORTED extends STATE
private val newVersion = version + 1
- private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
- private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true))
@volatile private var state: STATE = UPDATING
- @volatile private var finalDeltaFile: Path = null
+ private val finalDeltaFile: Path = deltaFile(newVersion)
+ private lazy val deltaFileStream = fm.createAtomic(finalDeltaFile, overwriteIfPossible = true)
+ private lazy val compressedStream = compressStream(deltaFileStream)
override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId
@@ -103,14 +104,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
val keyCopy = key.copy()
val valueCopy = value.copy()
mapToUpdate.put(keyCopy, valueCopy)
- writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy)
+ writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy)
}
override def remove(key: UnsafeRow): Unit = {
verify(state == UPDATING, "Cannot remove after already committed or aborted")
val prevValue = mapToUpdate.remove(key)
if (prevValue != null) {
- writeRemoveToDeltaFile(tempDeltaFileStream, key)
+ writeRemoveToDeltaFile(compressedStream, key)
}
}
@@ -126,8 +127,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
verify(state == UPDATING, "Cannot commit after already committed or aborted")
try {
- finalizeDeltaFile(tempDeltaFileStream)
- finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile)
+ commitUpdates(newVersion, mapToUpdate, compressedStream)
state = COMMITTED
logInfo(s"Committed version $newVersion for $this to file $finalDeltaFile")
newVersion
@@ -140,23 +140,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
/** Abort all the updates made on this store. This store will not be usable any more. */
override def abort(): Unit = {
- verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed")
- try {
+ // This if statement is to ensure that files are deleted only if there are changes to the
+ // StateStore. We have two StateStores for each task, one which is used only for reading, and
+      // the other used for read+write. We don't want the read-only store to delete state files.
+ if (state == UPDATING) {
+ state = ABORTED
+ cancelDeltaFile(compressedStream, deltaFileStream)
+ } else {
state = ABORTED
- if (tempDeltaFileStream != null) {
- tempDeltaFileStream.close()
- }
- if (tempDeltaFile != null) {
- fs.delete(tempDeltaFile, true)
- }
- } catch {
- case c: ClosedChannelException =>
- // This can happen when underlying file output stream has been closed before the
- // compression stream.
- logDebug(s"Error aborting version $newVersion into $this", c)
-
- case e: Exception =>
- logWarning(s"Error aborting version $newVersion into $this", e)
}
logInfo(s"Aborted version $newVersion for $this")
}
@@ -212,7 +203,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
this.valueSchema = valueSchema
this.storeConf = storeConf
this.hadoopConf = hadoopConf
- fs.mkdirs(baseDir)
+ fm.mkdirs(baseDir)
}
override def stateStoreId: StateStoreId = stateStoreId_
@@ -251,31 +242,15 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
private lazy val loadedMaps = new mutable.HashMap[Long, MapType]
private lazy val baseDir = stateStoreId.storeCheckpointLocation()
- private lazy val fs = baseDir.getFileSystem(hadoopConf)
+ private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf)
private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
- /** Commit a set of updates to the store with the given new version */
- private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = {
+ private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = {
synchronized {
- val finalDeltaFile = deltaFile(newVersion)
-
- // scalastyle:off
- // Renaming a file atop an existing one fails on HDFS
- // (http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html).
- // Hence we should either skip the rename step or delete the target file. Because deleting the
- // target file will break speculation, skipping the rename step is the only choice. It's still
- // semantically correct because Structured Streaming requires rerunning a batch should
- // generate the same output. (SPARK-19677)
- // scalastyle:on
- if (fs.exists(finalDeltaFile)) {
- fs.delete(tempDeltaFile, true)
- } else if (!fs.rename(tempDeltaFile, finalDeltaFile)) {
- throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile")
- }
+ finalizeDeltaFile(output)
loadedMaps.put(newVersion, map)
- finalDeltaFile
}
}
@@ -303,38 +278,49 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
if (loadedCurrentVersionMap.isDefined) {
return loadedCurrentVersionMap.get
}
- val snapshotCurrentVersionMap = readSnapshotFile(version)
- if (snapshotCurrentVersionMap.isDefined) {
- synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) }
- return snapshotCurrentVersionMap.get
- }
- // Find the most recent map before this version that we can.
- // [SPARK-22305] This must be done iteratively to avoid stack overflow.
- var lastAvailableVersion = version
- var lastAvailableMap: Option[MapType] = None
- while (lastAvailableMap.isEmpty) {
- lastAvailableVersion -= 1
+ logWarning(s"The state for version $version doesn't exist in loadedMaps. " +
+ "Reading snapshot file and delta files if needed..." +
+ "Note that this is normal for the first batch of starting query.")
- if (lastAvailableVersion <= 0) {
- // Use an empty map for versions 0 or less.
- lastAvailableMap = Some(new MapType)
- } else {
- lastAvailableMap =
- synchronized { loadedMaps.get(lastAvailableVersion) }
- .orElse(readSnapshotFile(lastAvailableVersion))
+ val (result, elapsedMs) = Utils.timeTakenMs {
+ val snapshotCurrentVersionMap = readSnapshotFile(version)
+ if (snapshotCurrentVersionMap.isDefined) {
+ synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) }
+ return snapshotCurrentVersionMap.get
+ }
+
+ // Find the most recent map before this version that we can.
+ // [SPARK-22305] This must be done iteratively to avoid stack overflow.
+ var lastAvailableVersion = version
+ var lastAvailableMap: Option[MapType] = None
+ while (lastAvailableMap.isEmpty) {
+ lastAvailableVersion -= 1
+
+ if (lastAvailableVersion <= 0) {
+ // Use an empty map for versions 0 or less.
+ lastAvailableMap = Some(new MapType)
+ } else {
+ lastAvailableMap =
+ synchronized { loadedMaps.get(lastAvailableVersion) }
+ .orElse(readSnapshotFile(lastAvailableVersion))
+ }
}
- }
- // Load all the deltas from the version after the last available one up to the target version.
- // The last available version is the one with a full snapshot, so it doesn't need deltas.
- val resultMap = new MapType(lastAvailableMap.get)
- for (deltaVersion <- lastAvailableVersion + 1 to version) {
- updateFromDeltaFile(deltaVersion, resultMap)
+ // Load all the deltas from the version after the last available one up to the target version.
+ // The last available version is the one with a full snapshot, so it doesn't need deltas.
+ val resultMap = new MapType(lastAvailableMap.get)
+ for (deltaVersion <- lastAvailableVersion + 1 to version) {
+ updateFromDeltaFile(deltaVersion, resultMap)
+ }
+
+ synchronized { loadedMaps.put(version, resultMap) }
+ resultMap
}
- synchronized { loadedMaps.put(version, resultMap) }
- resultMap
+ logDebug(s"Loading state for $version takes $elapsedMs ms.")
+
+ result
}
private def writeUpdateToDeltaFile(
@@ -365,7 +351,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
val fileToRead = deltaFile(version)
var input: DataInputStream = null
val sourceStream = try {
- fs.open(fileToRead)
+ fm.open(fileToRead)
} catch {
case f: FileNotFoundException =>
throw new IllegalStateException(
@@ -412,12 +398,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
}
private def writeSnapshotFile(version: Long, map: MapType): Unit = {
- val fileToWrite = snapshotFile(version)
- val tempFile =
- new Path(fileToWrite.getParent, s"${fileToWrite.getName}.temp-${Random.nextLong}")
+ val targetFile = snapshotFile(version)
+ var rawOutput: CancellableFSDataOutputStream = null
var output: DataOutputStream = null
- Utils.tryWithSafeFinally {
- output = compressStream(fs.create(tempFile, false))
+ try {
+ rawOutput = fm.createAtomic(targetFile, overwriteIfPossible = true)
+ output = compressStream(rawOutput)
val iter = map.entrySet().iterator()
while(iter.hasNext) {
val entry = iter.next()
@@ -429,16 +415,34 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
output.write(valueBytes)
}
output.writeInt(-1)
- } {
- if (output != null) output.close()
+ output.close()
+ } catch {
+ case e: Throwable =>
+ cancelDeltaFile(compressedStream = output, rawStream = rawOutput)
+ throw e
}
- if (fs.exists(fileToWrite)) {
- // Skip rename if the file is alreayd created.
- fs.delete(tempFile, true)
- } else if (!fs.rename(tempFile, fileToWrite)) {
- throw new IOException(s"Failed to rename $tempFile to $fileToWrite")
+ logInfo(s"Written snapshot file for version $version of $this at $targetFile")
+ }
+
+ /**
+ * Try to cancel the underlying stream and safely close the compressed stream.
+ *
+ * @param compressedStream the compressed stream.
+ * @param rawStream the underlying stream which needs to be cancelled.
+ */
+ private def cancelDeltaFile(
+ compressedStream: DataOutputStream,
+ rawStream: CancellableFSDataOutputStream): Unit = {
+ try {
+ if (rawStream != null) rawStream.cancel()
+ IOUtils.closeQuietly(compressedStream)
+ } catch {
+ case e: FSError if e.getCause.isInstanceOf[IOException] =>
+        // Closing the compressedStream causes the stream to write/flush data into the
+        // rawStream. Since the rawStream is already closed, there may be errors.
+        // Usually it's an IOException. However, Hadoop's RawLocalFileSystem wraps
+ // IOException into FSError.
}
- logInfo(s"Written snapshot file for version $version of $this at $fileToWrite")
}
private def readSnapshotFile(version: Long): Option[MapType] = {
@@ -447,7 +451,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
var input: DataInputStream = null
try {
- input = decompressStream(fs.open(fileToRead))
+ input = decompressStream(fm.open(fileToRead))
var eof = false
while (!eof) {
@@ -495,7 +499,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
/** Perform a snapshot of the store to allow delta files to be consolidated */
private def doSnapshot(): Unit = {
try {
- val files = fetchFiles()
+ val (files, e1) = Utils.timeTakenMs(fetchFiles())
+ logDebug(s"fetchFiles() took $e1 ms.")
+
if (files.nonEmpty) {
val lastVersion = files.last.version
val deltaFilesForLastVersion =
@@ -503,12 +509,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
synchronized { loadedMaps.get(lastVersion) } match {
case Some(map) =>
if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) {
- writeSnapshotFile(lastVersion, map)
+ val (_, e2) = Utils.timeTakenMs(writeSnapshotFile(lastVersion, map))
+ logDebug(s"writeSnapshotFile() took $e2 ms.")
}
case None =>
// The last map is not loaded, probably some other instance is in charge
}
-
}
} catch {
case NonFatal(e) =>
@@ -523,7 +529,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
*/
private[state] def cleanup(): Unit = {
try {
- val files = fetchFiles()
+ val (files, e1) = Utils.timeTakenMs(fetchFiles())
+ logDebug(s"fetchFiles() took $e1 ms.")
+
if (files.nonEmpty) {
val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain
if (earliestVersionToRetain > 0) {
@@ -533,9 +541,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
mapsToRemove.foreach(loadedMaps.remove)
}
val filesToDelete = files.filter(_.version < earliestFileToRetain.version)
- filesToDelete.foreach { f =>
- fs.delete(f.path, true)
+ val (_, e2) = Utils.timeTakenMs {
+ filesToDelete.foreach { f =>
+ fm.delete(f.path)
+ }
}
+ logDebug(s"deleting files took $e2 ms.")
logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " +
filesToDelete.mkString(", "))
}
@@ -576,7 +587,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
/** Fetch all the files that back the store */
private def fetchFiles(): Seq[StoreFile] = {
val files: Seq[FileStatus] = try {
- fs.listStatus(baseDir)
+ fm.list(baseDir)
} catch {
case _: java.io.FileNotFoundException =>
Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index d1d9f95cb0977..7eb68c21569ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -459,7 +459,6 @@ object StateStore extends Logging {
private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
val env = SparkEnv.get
if (env != null) {
- logInfo("Env is not null")
val isDriver =
env.executorId == SparkContext.DRIVER_IDENTIFIER ||
env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER
@@ -467,13 +466,12 @@ object StateStore extends Logging {
// as SparkContext + SparkEnv may have been restarted. Hence, when running in driver,
// always recreate the reference.
if (isDriver || _coordRef == null) {
- logInfo("Getting StateStoreCoordinatorRef")
+ logDebug("Getting StateStoreCoordinatorRef")
_coordRef = StateStoreCoordinatorRef.forExecutor(env)
}
logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
Some(_coordRef)
} else {
- logInfo("Env is null")
_coordRef = null
None
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index 01d8e75980993..3f11b8f79943c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.streaming.continuous.EpochTracker
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
@@ -71,8 +72,15 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
StateStoreId(checkpointLocation, operatorId, partition.index),
queryRunId)
+ // If we're in continuous processing mode, we should get the store version for the current
+ // epoch rather than the one at planning time.
+ val currentVersion = EpochTracker.getCurrentEpoch match {
+ case None => storeVersion
+ case Some(value) => value
+ }
+
store = StateStore.get(
- storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion,
+ storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion,
storeConf, hadoopConfBroadcast.value.value)
val inputIter = dataRDD.iterator(partition, ctxt)
storeUpdateFunction(store, inputIter)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index b9b07a2e688f9..6759fb42b4052 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
import org.apache.spark.sql.types._
-import org.apache.spark.util.{CompletionIterator, NextIterator}
+import org.apache.spark.util.{CompletionIterator, NextIterator, Utils}
/** Used to identify the state store for a given operator. */
@@ -97,12 +97,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
}
/** Records the duration of running `body` for the next query progress update. */
- protected def timeTakenMs(body: => Unit): Long = {
- val startTime = System.nanoTime()
- val result = body
- val endTime = System.nanoTime()
- math.max(NANOSECONDS.toMillis(endTime - startTime), 0)
- }
+ protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2
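For context, `Utils.timeTakenMs` evaluates the body and returns it together with the elapsed wall-clock time in milliseconds. A minimal sketch of that contract (illustrative only; the real helper lives in the private[spark] `org.apache.spark.util.Utils`):

  // Illustrative re-implementation of the (result, elapsedMs) contract.
  def timeTakenMs[T](body: => T): (T, Long) = {
    val startTime = System.nanoTime()
    val result = body
    val elapsedMs = java.util.concurrent.TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)
    (result, math.max(elapsedMs, 0))
  }

  // Usage: keep both the value and the duration of an arbitrary block.
  val (sum, elapsedMs) = timeTakenMs { (1 to 1000).sum }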
/**
* Set the SQL metrics related to the state store.
@@ -126,6 +121,12 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
name -> SQLMetrics.createTimingMetric(sparkContext, desc)
}.toMap
}
+
+ /**
+ * Should the MicroBatchExecution run another batch based on this stateful operator and the
+ * current updated metadata.
+ */
+ def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = false
}
/** An operator that supports watermark. */
@@ -340,37 +341,35 @@ case class StateStoreSaveExec(
// Update and output modified rows from the StateStore.
case Some(Update) =>
- val updatesStartTimeNs = System.nanoTime
-
- new Iterator[InternalRow] {
-
+ new NextIterator[InternalRow] {
// Filter late data using watermark if specified
private[this] val baseIterator = watermarkPredicateForData match {
case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
case None => iter
}
+ private val updatesStartTimeNs = System.nanoTime
- override def hasNext: Boolean = {
- if (!baseIterator.hasNext) {
- allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
-
- // Remove old aggregates if watermark specified
- allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
- commitTimeMs += timeTakenMs { store.commit() }
- setStoreMetrics(store)
- false
+ override protected def getNext(): InternalRow = {
+ if (baseIterator.hasNext) {
+ val row = baseIterator.next().asInstanceOf[UnsafeRow]
+ val key = getKey(row)
+ store.put(key, row)
+ numOutputRows += 1
+ numUpdatedStateRows += 1
+ row
} else {
- true
+ finished = true
+ null
}
}
- override def next(): InternalRow = {
- val row = baseIterator.next().asInstanceOf[UnsafeRow]
- val key = getKey(row)
- store.put(key, row)
- numOutputRows += 1
- numUpdatedStateRows += 1
- row
+ override protected def close(): Unit = {
+ allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
+ // Remove old aggregates if watermark specified
+ allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+ commitTimeMs += timeTakenMs { store.commit() }
+ setStoreMetrics(store)
}
}
@@ -390,6 +389,12 @@ case class StateStoreSaveExec(
ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
}
}
+
+ override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
+ (outputMode.contains(Append) || outputMode.contains(Update)) &&
+ eventTimeWatermark.isDefined &&
+ newMetadata.batchWatermarkMs > eventTimeWatermark.get
+ }
}
/** Physical operator for executing streaming Deduplicate. */
@@ -456,6 +461,10 @@ case class StreamingDeduplicateExec(
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
+ eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
+ }
}
object StreamingDeduplicateExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
index e751ce39cd5d7..bf46bc4cf904d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
@@ -39,7 +39,8 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L
val failed = new mutable.ArrayBuffer[SQLExecutionUIData]()
sqlStore.executionsList().foreach { e =>
- val isRunning = e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING }
+ val isRunning = e.completionTime.isEmpty ||
+ e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING }
val isFailed = e.jobs.exists { case (_, status) => status == JobExecutionStatus.FAILED }
if (isRunning) {
running += e
@@ -57,21 +58,21 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L
_content ++=
new RunningExecutionTable(
parent, s"Running Queries (${running.size})", currentTime,
- running.sortBy(_.submissionTime).reverse).toNodeSeq
+ running.sortBy(_.submissionTime).reverse).toNodeSeq(request)
}
if (completed.nonEmpty) {
_content ++=
new CompletedExecutionTable(
parent, s"Completed Queries (${completed.size})", currentTime,
- completed.sortBy(_.submissionTime).reverse).toNodeSeq
+ completed.sortBy(_.submissionTime).reverse).toNodeSeq(request)
}
if (failed.nonEmpty) {
_content ++=
new FailedExecutionTable(
parent, s"Failed Queries (${failed.size})", currentTime,
- failed.sortBy(_.submissionTime).reverse).toNodeSeq
+ failed.sortBy(_.submissionTime).reverse).toNodeSeq(request)
}
_content
}
@@ -110,7 +111,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L
}
- UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000))
+ UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent, Some(5000))
}
}
@@ -132,7 +133,10 @@ private[ui] abstract class ExecutionTable(
protected def header: Seq[String]
- protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = {
+ protected def row(
+ request: HttpServletRequest,
+ currentTime: Long,
+ executionUIData: SQLExecutionUIData): Seq[Node] = {
val submissionTime = executionUIData.submissionTime
val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) -
submissionTime
@@ -140,7 +144,7 @@ private[ui] abstract class ExecutionTable(
def jobLinks(status: JobExecutionStatus): Seq[Node] = {
executionUIData.jobs.flatMap { case (jobId, jobStatus) =>
if (jobStatus == status) {
- [{jobId.toString}]
+ [{jobId.toString}]
} else {
None
}
@@ -152,7 +156,7 @@ private[ui] abstract class ExecutionTable(
{executionUIData.executionId.toString}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index 53fb9a0cc21cf..d254af400a7cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -88,7 +88,7 @@ class SQLAppStatusListener(
exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING)
exec.stages ++= event.stageIds.toSet
- update(exec)
+ update(exec, force = true)
}
override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = {
@@ -289,7 +289,7 @@ class SQLAppStatusListener(
private def onDriverAccumUpdates(event: SparkListenerDriverAccumUpdates): Unit = {
val SparkListenerDriverAccumUpdates(executionId, accumUpdates) = event
Option(liveExecutions.get(executionId)).foreach { exec =>
- exec.driverAccumUpdates = accumUpdates.toMap
+ exec.driverAccumUpdates = exec.driverAccumUpdates ++ accumUpdates
update(exec)
}
}
@@ -308,11 +308,13 @@ class SQLAppStatusListener(
})
}
- private def update(exec: LiveExecutionData): Unit = {
+ private def update(exec: LiveExecutionData, force: Boolean = false): Unit = {
val now = System.nanoTime()
if (exec.endEvents >= exec.jobs.size + 1) {
exec.write(kvstore, now)
liveExecutions.remove(exec.executionId)
+ } else if (force) {
+ exec.write(kvstore, now)
} else if (liveUpdatePeriodNs >= 0) {
if (now - exec.lastWriteTime > liveUpdatePeriodNs) {
exec.write(kvstore, now)
@@ -334,7 +336,10 @@ class SQLAppStatusListener(
val view = kvstore.view(classOf[SQLExecutionUIData]).index("completionTime").first(0L)
val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt)(_.completionTime.isDefined)
- toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) }
+ toDelete.foreach { e =>
+ kvstore.delete(e.getClass(), e.executionId)
+ kvstore.delete(classOf[SparkPlanGraphWrapper], e.executionId)
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
index 9a76584717f42..241001a857c8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
@@ -54,6 +54,10 @@ class SQLAppStatusStore(
store.count(classOf[SQLExecutionUIData])
}
+ def planGraphCount(): Long = {
+ store.count(classOf[SparkPlanGraphWrapper])
+ }
+
def executionMetrics(executionId: Long): Map[Long, String] = {
def metricsFromStore(): Option[Map[Long, String]] = {
val exec = store.read(classOf[SQLExecutionUIData], executionId)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index 800a2ea3f3996..626f39d9e95cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -112,9 +112,11 @@ case class WindowExec(
*
* @param frame to evaluate. This can either be a Row or Range frame.
* @param bound with respect to the row.
+ * @param timeZone the session local timezone for time related calculations.
* @return a bound ordering object.
*/
- private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = {
+ private[this] def createBoundOrdering(
+ frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = {
(frame, bound) match {
case (RowFrame, CurrentRow) =>
RowBoundOrdering(0)
@@ -144,7 +146,7 @@ case class WindowExec(
val boundExpr = (expr.dataType, boundOffset.dataType) match {
case (DateType, IntegerType) => DateAdd(expr, boundOffset)
case (TimestampType, CalendarIntervalType) =>
- TimeAdd(expr, boundOffset, Some(conf.sessionLocalTimeZone))
+ TimeAdd(expr, boundOffset, Some(timeZone))
case (a, b) if a== b => Add(expr, boundOffset)
}
val bound = newMutableProjection(boundExpr :: Nil, child.output)
@@ -197,6 +199,7 @@ case class WindowExec(
// Map the groups to a (unbound) expression and frame factory pair.
var numExpressions = 0
+ val timeZone = conf.sessionLocalTimeZone
framedFunctions.toSeq.map {
case (key, (expressions, functionSeq)) =>
val ordinal = numExpressions
@@ -237,7 +240,7 @@ case class WindowExec(
new UnboundedPrecedingWindowFunctionFrame(
target,
processor,
- createBoundOrdering(frameType, upper))
+ createBoundOrdering(frameType, upper, timeZone))
}
// Shrinking Frame.
@@ -246,7 +249,7 @@ case class WindowExec(
new UnboundedFollowingWindowFunctionFrame(
target,
processor,
- createBoundOrdering(frameType, lower))
+ createBoundOrdering(frameType, lower, timeZone))
}
// Moving Frame.
@@ -255,8 +258,8 @@ case class WindowExec(
new SlidingWindowFunctionFrame(
target,
processor,
- createBoundOrdering(frameType, lower),
- createBoundOrdering(frameType, upper))
+ createBoundOrdering(frameType, lower, timeZone),
+ createBoundOrdering(frameType, upper, timeZone))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index 1caa243f8d118..cd819bab1b14c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -33,6 +33,10 @@ import org.apache.spark.sql.catalyst.expressions._
* Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3)
* }}}
*
+ * @note When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding,
+ * unboundedFollowing) is used by default. When ordering is defined, a growing window frame
+ * (rangeFrame, unboundedPreceding, currentRow) is used by default.
+ *
* @since 1.4.0
*/
@InterfaceStability.Stable
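A short illustration of the documented defaults, assuming a DataFrame `df` with hypothetical `country`, `date` and `sales` columns and `import org.apache.spark.sql.functions._`:

import org.apache.spark.sql.expressions.Window

// No ordering: unbounded frame, so max() sees the whole partition.
val wholePartition = Window.partitionBy("country")
// With ordering: growing frame (unboundedPreceding, currentRow), i.e. a running aggregate.
val running = Window.partitionBy("country").orderBy("date")

df.withColumn("partition_max", max("sales").over(wholePartition))
  .withColumn("running_max", max("sales").over(running))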
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0d54c02c3d06f..8551058ec58ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -132,7 +132,7 @@ object functions {
* Returns a sort expression based on ascending order of the column,
* and null values return before non-null values.
* {{{
- * df.sort(asc_nulls_last("dept"), desc("age"))
+ * df.sort(asc_nulls_first("dept"), desc("age"))
* }}}
*
* @group sort_funcs
@@ -283,6 +283,9 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
+ * @note The function is non-deterministic because the order of collected results depends
+ * on order of rows which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.6.0
*/
@@ -291,6 +294,9 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
+ * @note The function is non-deterministic because the order of collected results depends
+ * on order of rows which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.6.0
*/
@@ -299,6 +305,9 @@ object functions {
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
+ * @note The function is non-deterministic because the order of collected results depends
+ * on order of rows which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.6.0
*/
@@ -307,6 +316,9 @@ object functions {
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
+ * @note The function is non-deterministic because the order of collected results depends
+ * on order of rows which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.6.0
*/
@@ -422,6 +434,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 2.0.0
*/
@@ -435,6 +450,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 2.0.0
*/
@@ -448,6 +466,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.3.0
*/
@@ -459,6 +480,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.3.0
*/
@@ -535,6 +559,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 2.0.0
*/
@@ -548,6 +575,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 2.0.0
*/
@@ -561,6 +591,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.3.0
*/
@@ -572,6 +605,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
+ * @note The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
+ *
* @group agg_funcs
* @since 1.3.0
*/
@@ -775,6 +811,7 @@ object functions {
*/
def var_pop(columnName: String): Column = var_pop(Column(columnName))
+
//////////////////////////////////////////////////////////////////////////////////////////////
// Window functions
//////////////////////////////////////////////////////////////////////////////////////////////
@@ -1033,6 +1070,17 @@ object functions {
@scala.annotation.varargs
def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) }
+ /**
+ * Creates a new map column. The array in the first column is used for keys. The array in the
+ * second column is used for values. All elements in the keys array must be non-null.
+ *
+ * @group normal_funcs
+ * @since 2.4.0
+ */
+ def map_from_arrays(keys: Column, values: Column): Column = withExpr {
+ MapFromArrays(keys.expr, values.expr)
+ }
+
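A usage sketch for the new `map_from_arrays` helper, assuming a DataFrame `df` and `import org.apache.spark.sql.functions._` (the literal values are purely illustrative):

df.select(map_from_arrays(array(lit("a"), lit("b")), array(lit(1), lit(2))).as("m"))
// m = Map(a -> 1, b -> 2)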
/**
* Marks a DataFrame as small enough for use in broadcast joins.
*
@@ -1172,7 +1220,7 @@ object functions {
* Generate a random column with independent and identically distributed (i.i.d.) samples
* from U[0.0, 1.0].
*
- * @note This is indeterministic when data partitions are not fixed.
+ * @note The function is non-deterministic in the general case.
*
* @group normal_funcs
* @since 1.4.0
@@ -1183,6 +1231,8 @@ object functions {
* Generate a random column with independent and identically distributed (i.i.d.) samples
* from U[0.0, 1.0].
*
+ * @note The function is non-deterministic in the general case.
+ *
* @group normal_funcs
* @since 1.4.0
*/
@@ -1192,7 +1242,7 @@ object functions {
* Generate a column with independent and identically distributed (i.i.d.) samples from
* the standard normal distribution.
*
- * @note This is indeterministic when data partitions are not fixed.
+ * @note The function is non-deterministic in the general case.
*
* @group normal_funcs
* @since 1.4.0
@@ -1203,6 +1253,8 @@ object functions {
* Generate a column with independent and identically distributed (i.i.d.) samples from
* the standard normal distribution.
*
+ * @note The function is non-deterministic in the general case.
+ *
* @group normal_funcs
* @since 1.4.0
*/
@@ -1211,7 +1263,7 @@ object functions {
/**
* Partition ID.
*
- * @note This is indeterministic because it depends on data partitioning and task scheduling.
+ * @note This is non-deterministic because it depends on data partitioning and task scheduling.
*
* @group normal_funcs
* @since 1.6.0
@@ -1313,8 +1365,7 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////
/**
- * Computes the cosine inverse of the given value; the returned angle is in the range
- * 0.0 through pi.
+ * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos`
*
* @group math_funcs
* @since 1.4.0
@@ -1322,8 +1373,7 @@ object functions {
def acos(e: Column): Column = withExpr { Acos(e.expr) }
/**
- * Computes the cosine inverse of the given column; the returned angle is in the range
- * 0.0 through pi.
+ * @return inverse cosine of `columnName`, as if computed by `java.lang.Math.acos`
*
* @group math_funcs
* @since 1.4.0
@@ -1331,8 +1381,7 @@ object functions {
def acos(columnName: String): Column = acos(Column(columnName))
/**
- * Computes the sine inverse of the given value; the returned angle is in the range
- * -pi/2 through pi/2.
+ * @return inverse sine of `e` in radians, as if computed by `java.lang.Math.asin`
*
* @group math_funcs
* @since 1.4.0
@@ -1340,8 +1389,7 @@ object functions {
def asin(e: Column): Column = withExpr { Asin(e.expr) }
/**
- * Computes the sine inverse of the given column; the returned angle is in the range
- * -pi/2 through pi/2.
+ * @return inverse sine of `columnName`, as if computed by `java.lang.Math.asin`
*
* @group math_funcs
* @since 1.4.0
@@ -1349,8 +1397,7 @@ object functions {
def asin(columnName: String): Column = asin(Column(columnName))
/**
- * Computes the tangent inverse of the given column; the returned angle is in the range
- * -pi/2 through pi/2
+ * @return inverse tangent of `e`, as if computed by `java.lang.Math.atan`
*
* @group math_funcs
* @since 1.4.0
@@ -1358,8 +1405,7 @@ object functions {
def atan(e: Column): Column = withExpr { Atan(e.expr) }
/**
- * Computes the tangent inverse of the given column; the returned angle is in the range
- * -pi/2 through pi/2
+ * @return inverse tangent of `columnName`, as if computed by `java.lang.Math.atan`
*
* @group math_funcs
* @since 1.4.0
@@ -1367,77 +1413,117 @@ object functions {
def atan(columnName: String): Column = atan(Column(columnName))
/**
- * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
- * polar coordinates (r, theta). Units in radians.
+ * @param y coordinate on y-axis
+ * @param x coordinate on x-axis
+ * @return the theta component of the point
+ * (r, theta)
+ * in polar coordinates that corresponds to the point
+ * (x, y) in Cartesian coordinates,
+ * as if computed by `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
*/
- def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) }
+ def atan2(y: Column, x: Column): Column = withExpr { Atan2(y.expr, x.expr) }
/**
- * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
- * polar coordinates (r, theta).
+ * @param y coordinate on y-axis
+ * @param xName coordinate on x-axis
+ * @return the theta component of the point
+ * (r, theta)
+ * in polar coordinates that corresponds to the point
+ * (x, y) in Cartesian coordinates,
+ * as if computed by `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
*/
- def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName))
+ def atan2(y: Column, xName: String): Column = atan2(y, Column(xName))
/**
- * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
- * polar coordinates (r, theta).
+ * @param yName coordinate on y-axis
+ * @param x coordinate on x-axis
+ * @return the theta component of the point
+ * (r, theta)
+ * in polar coordinates that corresponds to the point
+ * (x, y) in Cartesian coordinates,
+ * as if computed by `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
*/
- def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r)
+ def atan2(yName: String, x: Column): Column = atan2(Column(yName), x)
/**
- * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
- * polar coordinates (r, theta).
+ * @param yName coordinate on y-axis
+ * @param xName coordinate on x-axis
+ * @return the theta component of the point
+ * (r, theta)
+ * in polar coordinates that corresponds to the point
+ * (x, y) in Cartesian coordinates,
+ * as if computed by `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
*/
- def atan2(leftName: String, rightName: String): Column =
- atan2(Column(leftName), Column(rightName))
+ def atan2(yName: String, xName: String): Column =
+ atan2(Column(yName), Column(xName))
/**
- * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
- * polar coordinates (r, theta).
+ * @param y coordinate on y-axis
+ * @param xValue coordinate on x-axis
+ * @return the theta component of the point
+ * (r, theta)
+ * in polar coordinates that corresponds to the point
+ * (x, y) in Cartesian coordinates,
+ * as if computed by `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
*/
- def atan2(l: Column, r: Double): Column = atan2(l, lit(r))
+ def atan2(y: Column, xValue: Double): Column = atan2(y, lit(xValue))
/**
- * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
- * polar coordinates (r, theta).
+ * @param yName coordinate on y-axis
+ * @param xValue coordinate on x-axis
+ * @return the theta component of the point
+ * (r, theta)
+ * in polar coordinates that corresponds to the point
+ * (x, y) in Cartesian coordinates,
+ * as if computed by `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
*/
- def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r)
+ def atan2(yName: String, xValue: Double): Column = atan2(Column(yName), xValue)
/**
- * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
- * polar coordinates (r, theta).
+ * @param yValue coordinate on y-axis
+ * @param x coordinate on x-axis
+ * @return the theta component of the point
+ * (r, theta)
+ * in polar coordinates that corresponds to the point
+ * (x, y) in Cartesian coordinates,
+ * as if computed by `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
*/
- def atan2(l: Double, r: Column): Column = atan2(lit(l), r)
+ def atan2(yValue: Double, x: Column): Column = atan2(lit(yValue), x)
/**
- * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
- * polar coordinates (r, theta).
+ * @param yValue coordinate on y-axis
+ * @param xName coordinate on x-axis
+ * @return the theta component of the point
+ * (r, theta)
+ * in polar coordinates that corresponds to the point
+ * (x, y) in Cartesian coordinates,
+ * as if computed by `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
*/
- def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName))
+ def atan2(yValue: Double, xName: String): Column = atan2(yValue, Column(xName))
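The reworded scaladoc follows the `java.lang.Math.atan2(y, x)` argument order; a quick sketch assuming a DataFrame `df`:

df.select(atan2(lit(1.0), lit(1.0)).as("theta"))   // theta = pi/4 ~ 0.7853981633974483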
/**
* An expression that returns the string representation of the binary value of the given long
@@ -1500,7 +1586,8 @@ object functions {
}
/**
- * Computes the cosine of the given value. Units in radians.
+ * @param e angle in radians
+ * @return cosine of the angle, as if computed by `java.lang.Math.cos`
*
* @group math_funcs
* @since 1.4.0
@@ -1508,7 +1595,8 @@ object functions {
def cos(e: Column): Column = withExpr { Cos(e.expr) }
/**
- * Computes the cosine of the given column.
+ * @param columnName angle in radians
+ * @return cosine of the angle, as if computed by `java.lang.Math.cos`
*
* @group math_funcs
* @since 1.4.0
@@ -1516,7 +1604,8 @@ object functions {
def cos(columnName: String): Column = cos(Column(columnName))
/**
- * Computes the hyperbolic cosine of the given value.
+ * @param e hyperbolic angle
+ * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh`
*
* @group math_funcs
* @since 1.4.0
@@ -1524,7 +1613,8 @@ object functions {
def cosh(e: Column): Column = withExpr { Cosh(e.expr) }
/**
- * Computes the hyperbolic cosine of the given column.
+ * @param columnName hyperbolic angle
+ * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh`
*
* @group math_funcs
* @since 1.4.0
@@ -1967,7 +2057,8 @@ object functions {
def signum(columnName: String): Column = signum(Column(columnName))
/**
- * Computes the sine of the given value. Units in radians.
+ * @param e angle in radians
+ * @return sine of the angle, as if computed by `java.lang.Math.sin`
*
* @group math_funcs
* @since 1.4.0
@@ -1975,7 +2066,8 @@ object functions {
def sin(e: Column): Column = withExpr { Sin(e.expr) }
/**
- * Computes the sine of the given column.
+ * @param columnName angle in radians
+ * @return sine of the angle, as if computed by `java.lang.Math.sin`
*
* @group math_funcs
* @since 1.4.0
@@ -1983,7 +2075,8 @@ object functions {
def sin(columnName: String): Column = sin(Column(columnName))
/**
- * Computes the hyperbolic sine of the given value.
+ * @param e hyperbolic angle
+ * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh`
*
* @group math_funcs
* @since 1.4.0
@@ -1991,7 +2084,8 @@ object functions {
def sinh(e: Column): Column = withExpr { Sinh(e.expr) }
/**
- * Computes the hyperbolic sine of the given column.
+ * @param columnName hyperbolic angle
+ * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh`
*
* @group math_funcs
* @since 1.4.0
@@ -1999,7 +2093,8 @@ object functions {
def sinh(columnName: String): Column = sinh(Column(columnName))
/**
- * Computes the tangent of the given value. Units in radians.
+ * @param e angle in radians
+ * @return tangent of the given value, as if computed by `java.lang.Math.tan`
*
* @group math_funcs
* @since 1.4.0
@@ -2007,7 +2102,8 @@ object functions {
def tan(e: Column): Column = withExpr { Tan(e.expr) }
/**
- * Computes the tangent of the given column.
+ * @param columnName angle in radians
+ * @return tangent of the given value, as if computed by `java.lang.Math.tan`
*
* @group math_funcs
* @since 1.4.0
@@ -2015,7 +2111,8 @@ object functions {
def tan(columnName: String): Column = tan(Column(columnName))
/**
- * Computes the hyperbolic tangent of the given value.
+ * @param e hyperbolic angle
+ * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh`
*
* @group math_funcs
* @since 1.4.0
@@ -2023,7 +2120,8 @@ object functions {
def tanh(e: Column): Column = withExpr { Tanh(e.expr) }
/**
- * Computes the hyperbolic tangent of the given column.
+ * @param columnName hyperbolic angle
+ * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh`
*
* @group math_funcs
* @since 1.4.0
@@ -2047,6 +2145,9 @@ object functions {
/**
* Converts an angle measured in radians to an approximately equivalent angle measured in degrees.
*
+ * @param e angle in radians
+ * @return angle in degrees, as if computed by `java.lang.Math.toDegrees`
+ *
* @group math_funcs
* @since 2.1.0
*/
@@ -2055,6 +2156,9 @@ object functions {
/**
* Converts an angle measured in radians to an approximately equivalent angle measured in degrees.
*
+ * @param columnName angle in radians
+ * @return angle in degrees, as if computed by `java.lang.Math.toDegrees`
+ *
* @group math_funcs
* @since 2.1.0
*/
@@ -2077,6 +2181,9 @@ object functions {
/**
* Converts an angle measured in degrees to an approximately equivalent angle measured in radians.
*
+ * @param e angle in degrees
+ * @return angle in radians, as if computed by `java.lang.Math.toRadians`
+ *
* @group math_funcs
* @since 2.1.0
*/
@@ -2085,6 +2192,9 @@ object functions {
/**
* Converts an angle measured in degrees to an approximately equivalent angle measured in radians.
*
+ * @param columnName angle in degrees
+ * @return angle in radians, as if computed by `java.lang.Math.toRadians`
+ *
* @group math_funcs
* @since 2.1.0
*/
@@ -2170,16 +2280,6 @@ object functions {
*/
def base64(e: Column): Column = withExpr { Base64(e.expr) }
- /**
- * Concatenates multiple input columns together into a single column.
- * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- @scala.annotation.varargs
- def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
-
/**
* Concatenates multiple input string columns together into a single string column,
* using the given separator.
@@ -2406,14 +2506,6 @@ object functions {
StringRepeat(str.expr, lit(n).expr)
}
- /**
- * Reverses the string column and returns it as a new string column.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def reverse(str: Column): Column = withExpr { StringReverse(str.expr) }
-
/**
* Trim the spaces from right end for the specified string value.
*
@@ -2651,11 +2743,27 @@ object functions {
/**
* Returns number of months between dates `date1` and `date2`.
+ * If `date1` is later than `date2`, then the result is positive.
+ * If `date1` and `date2` are on the same day of month, or both are the last day of month,
+ * time of day will be ignored.
+ *
+ * Otherwise, the difference is calculated based on 31 days per month, and rounded to
+ * 8 digits.
* @group datetime_funcs
* @since 1.5.0
*/
def months_between(date1: Column, date2: Column): Column = withExpr {
- MonthsBetween(date1.expr, date2.expr)
+ new MonthsBetween(date1.expr, date2.expr)
+ }
+
+ /**
+ * Returns number of months between dates `date1` and `date2`. If `roundOff` is set to true, the
+ * result is rounded off to 8 digits; it is not rounded otherwise.
+ * @group datetime_funcs
+ * @since 2.4.0
+ */
+ def months_between(date1: Column, date2: Column, roundOff: Boolean): Column = withExpr {
+ MonthsBetween(date1.expr, date2.expr, lit(roundOff).expr)
}
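A sketch of the new three-argument overload next to the existing one, assuming a DataFrame `df` (the dates are illustrative):

val d1 = to_date(lit("2018-04-30"))
val d2 = to_date(lit("2018-02-27"))
df.select(
  months_between(d1, d2).as("rounded"),                // rounded off to 8 digits
  months_between(d1, d2, roundOff = false).as("raw"))  // unrounded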
/**
@@ -2873,7 +2981,7 @@ object functions {
* or equal to the `windowDuration`. Check
* `org.apache.spark.unsafe.types.CalendarInterval` for valid duration
* identifiers. This duration is likewise absolute, and does not vary
- * according to a calendar.
+ * according to a calendar.
* @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start
* window intervals. For example, in order to have hourly tumbling windows that
* start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide
@@ -2929,7 +3037,7 @@ object functions {
* or equal to the `windowDuration`. Check
* `org.apache.spark.unsafe.types.CalendarInterval` for valid duration
* identifiers. This duration is likewise absolute, and does not vary
- * according to a calendar.
+ * according to a calendar.
*
* @group datetime_funcs
* @since 2.0.0
@@ -2988,6 +3096,99 @@ object functions {
ArrayContains(column.expr, Literal(value))
}
+ /**
+ * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and both
+ * the arrays are non-empty and any of them contains a `null`, it returns `null`. It returns
+ * `false` otherwise.
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def arrays_overlap(a1: Column, a2: Column): Column = withExpr {
+ ArraysOverlap(a1.expr, a2.expr)
+ }
+
+ /**
+ * Returns an array containing all the elements in `x` from index `start` (or starting from the
+ * end if `start` is negative) with the specified `length`.
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def slice(x: Column, start: Int, length: Int): Column = withExpr {
+ Slice(x.expr, Literal(start), Literal(length))
+ }
+
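Quick sketches for the two helpers above, assuming a SparkSession `spark` and `import org.apache.spark.sql.functions._` (the data is invented):

val arrs = spark.range(1).select(
  array(lit(1), lit(2), lit(3), lit(4)).as("xs"),
  array(lit(3), lit(5)).as("ys"))
arrs.select(
  arrays_overlap(col("xs"), col("ys")).as("overlap"),  // true: 3 appears in both
  slice(col("xs"), 2, 2).as("middle"))                 // [2, 3]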
+ /**
+ * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with
+ * `nullReplacement`.
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_join(column: Column, delimiter: String, nullReplacement: String): Column = withExpr {
+ ArrayJoin(column.expr, Literal(delimiter), Some(Literal(nullReplacement)))
+ }
+
+ /**
+ * Concatenates the elements of `column` using the `delimiter`.
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_join(column: Column, delimiter: String): Column = withExpr {
+ ArrayJoin(column.expr, Literal(delimiter), None)
+ }
+
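A usage sketch covering both `array_join` overloads (invented data, same assumptions as above):

val words = spark.range(1).select(array(lit("a"), lit(null).cast("string"), lit("c")).as("xs"))
words.select(
  array_join(col("xs"), ","),        // nulls are skipped: "a,c"
  array_join(col("xs"), ",", "?"))   // nulls are replaced: "a,?,c"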
+ /**
+ * Concatenates multiple input columns together into a single column.
+ * The function works with strings, binary and compatible array columns.
+ *
+ * @group collection_funcs
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
+
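With the move to the collection section, `concat` also accepts array columns, e.g.:

df.select(concat(array(lit(1), lit(2)), array(lit(3))).as("xs"))   // [1, 2, 3]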
+ /**
+ * Locates the position of the first occurrence of the value in the given array as long.
+ * Returns null if either of the arguments is null.
+ *
+ * @note The position is not zero-based, but 1-based. Returns 0 if the value
+ * could not be found in the array.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_position(column: Column, value: Any): Column = withExpr {
+ ArrayPosition(column.expr, Literal(value))
+ }
+
+ /**
+ * Returns the element of the array at the given index if the column is an array, or the value
+ * for the given key if the column is a map.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def element_at(column: Column, value: Any): Column = withExpr {
+ ElementAt(column.expr, Literal(value))
+ }
+
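Illustrative calls for the two lookup helpers (note the 1-based indexing; data and column names are invented):

val row = spark.range(1).select(
  array(lit("a"), lit("b"), lit("c")).as("xs"),
  map(lit("k"), lit(1)).as("m"))
row.select(
  array_position(col("xs"), "b"),   // 2 (as long); 0 when not found
  element_at(col("xs"), 3),         // "c"
  element_at(col("m"), "k"))        // 1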
+ /**
+ * Sorts the input array in ascending order. The elements of the input array must be orderable.
+ * Null elements will be placed at the end of the returned array.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) }
+
+ /**
+ * Removes all elements equal to `element` from the given array.
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_remove(column: Column, element: Any): Column = withExpr {
+ ArrayRemove(column.expr, Literal(element))
+ }
+
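And a sketch for the sorting/removal pair (invented data):

val xs = spark.range(1).select(array(lit(3), lit(1), lit(3), lit(2)).as("xs"))
xs.select(
  array_sort(col("xs")),        // [1, 2, 3, 3]; nulls would be placed last
  array_remove(col("xs"), 3))   // [1, 2]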
/**
* Creates a new row for each element in the given array or map column.
*
@@ -3061,9 +3262,9 @@ object functions {
from_json(e, schema.asInstanceOf[DataType], options)
/**
- * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
- * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable
- * string.
+ * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+ * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
@@ -3074,7 +3275,7 @@ object functions {
* @since 2.2.0
*/
def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr {
- JsonToStructs(schema, options, e.expr)
+ new JsonToStructs(schema, options, e.expr)
}
/**
@@ -3093,9 +3294,9 @@ object functions {
from_json(e, schema, options.asScala.toMap)
/**
- * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
- * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable
- * string.
+ * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+ * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
@@ -3122,8 +3323,9 @@ object functions {
from_json(e, schema, Map.empty[String, String])
/**
- * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s
- * with the specified schema. Returns `null`, in the case of an unparseable string.
+ * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type,
+ * `StructType` or `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
@@ -3135,9 +3337,9 @@ object functions {
from_json(e, schema, Map.empty[String, String])
/**
- * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
- * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable
- * string.
+ * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+ * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string as a json string. In Spark 2.1,
@@ -3152,9 +3354,9 @@ object functions {
}
/**
- * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
- * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable
- * string.
+ * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+ * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string as a json string, it could be a
@@ -3167,7 +3369,7 @@ object functions {
val dataType = try {
DataType.fromJson(schema)
} catch {
- case NonFatal(_) => StructType.fromDDL(schema)
+ case NonFatal(_) => DataType.fromDDL(schema)
}
from_json(e, dataType, options)
}
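With the switch to `DataType.fromDDL`, the string-schema overload accepts plain DDL in addition to JSON-formatted schemas; a hypothetical call (the `payload` column name is invented):

df.select(from_json(col("payload"), "a INT, b STRING", Map.empty[String, String]).as("parsed"))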
@@ -3227,6 +3429,7 @@ object functions {
/**
* Sorts the input array for the given column in ascending order,
* according to the natural ordering of the array elements.
+ * Null elements will be placed at the beginning of the returned array.
*
* @group collection_funcs
* @since 1.5.0
@@ -3236,12 +3439,65 @@ object functions {
/**
* Sorts the input array for the given column in ascending or descending order,
* according to the natural ordering of the array elements.
+ * Null elements will be placed at the beginning of the returned array in ascending order or
+ * at the end of the returned array in descending order.
*
* @group collection_funcs
* @since 1.5.0
*/
def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }
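As the updated docs note, null elements are kept: at the front in ascending order and at the back in descending order. For an assumed array column `xs`:

df.select(sort_array(col("xs")), sort_array(col("xs"), asc = false))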
+ /**
+ * Returns the minimum value in the array.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) }
+
+ /**
+ * Returns the maximum value in the array.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }
+
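For example, over an assumed array column `xs`:

df.select(array_min(col("xs")).as("lo"), array_max(col("xs")).as("hi"))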
+ /**
+ * Returns a reversed string or an array with reverse order of elements.
+ * @group collection_funcs
+ * @since 1.5.0
+ */
+ def reverse(e: Column): Column = withExpr { Reverse(e.expr) }
+
+ /**
+ * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than
+ * two levels, only one level of nesting is removed.
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def flatten(e: Column): Column = withExpr { Flatten(e.expr) }
+
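A combined sketch for the generalized `reverse` and the new `flatten` (literal values, for illustration):

df.select(
  reverse(lit("Spark")),                                  // "krapS"
  reverse(array(lit(1), lit(2), lit(3))),                 // [3, 2, 1]
  flatten(array(array(lit(1), lit(2)), array(lit(3)))))   // [1, 2, 3]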
+ /**
+ * Creates an array containing the left argument repeated the number of times given by the
+ * right argument.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_repeat(left: Column, right: Column): Column = withExpr {
+ ArrayRepeat(left.expr, right.expr)
+ }
+
+ /**
+ * Creates an array containing the left argument repeated the number of times given by the
+ * right argument.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count))
+
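For example (the column `n` below is hypothetical):

df.select(array_repeat(lit("x"), 3))          // ["x", "x", "x"]
df.select(array_repeat(lit("x"), col("n")))   // repeat count taken from another column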
/**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
@@ -3256,6 +3512,140 @@ object functions {
*/
def map_values(e: Column): Column = withExpr { MapValues(e.expr) }
+ /**
+ * Returns an unordered array of all entries in the given map.
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
+
+ /**
+ * Returns a merged array of structs in which the N-th struct contains all N-th values of input
+ * arrays.
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) }
+
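Usage sketches for the two helpers above (invented values, same `spark`/`functions._` assumptions as earlier):

val m = spark.range(1).select(map(lit("a"), lit(1), lit("b"), lit(2)).as("m"))
m.select(map_entries(col("m")))   // [{a, 1}, {b, 2}]

spark.range(1).select(
  arrays_zip(array(lit(1), lit(2)), array(lit("x"), lit("y"))))   // [{1, x}, {2, y}]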
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ // Mask functions
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ /**
+ * Returns a string which is the masked representation of the input.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask(e: Column): Column = withExpr { new Mask(e.expr) }
+
+ /**
+ * Returns a string which is the masked representation of the input, using `upper`, `lower` and
+ * `digit` as replacement characters.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask(e: Column, upper: String, lower: String, digit: String): Column = withExpr {
+ Mask(e.expr, upper, lower, digit)
+ }
+
+ /**
+ * Returns a string with the first `n` characters masked.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask_first_n(e: Column, n: Int): Column = withExpr { new MaskFirstN(e.expr, Literal(n)) }
+
+ /**
+ * Returns a string with the first `n` characters masked, using `upper`, `lower` and `digit` as
+ * replacement characters.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask_first_n(
+ e: Column,
+ n: Int,
+ upper: String,
+ lower: String,
+ digit: String): Column = withExpr {
+ MaskFirstN(e.expr, n, upper, lower, digit)
+ }
+
+ /**
+ * Returns a string with the last `n` characters masked.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask_last_n(e: Column, n: Int): Column = withExpr { new MaskLastN(e.expr, Literal(n)) }
+
+ /**
+ * Returns a string with the last `n` characters masked, using `upper`, `lower` and `digit` as
+ * replacement characters.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask_last_n(
+ e: Column,
+ n: Int,
+ upper: String,
+ lower: String,
+ digit: String): Column = withExpr {
+ MaskLastN(e.expr, n, upper, lower, digit)
+ }
+
+ /**
+ * Returns a string with all but the first `n` characters masked.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask_show_first_n(e: Column, n: Int): Column = withExpr {
+ new MaskShowFirstN(e.expr, Literal(n))
+ }
+
+ /**
+ * Returns a string with all but the first `n` characters masked, using `upper`, `lower` and
+ * `digit` as replacement characters.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask_show_first_n(
+ e: Column,
+ n: Int,
+ upper: String,
+ lower: String,
+ digit: String): Column = withExpr {
+ MaskShowFirstN(e.expr, n, upper, lower, digit)
+ }
+
+ /**
+ * Returns a string with all but the last `n` characters masked.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask_show_last_n(e: Column, n: Int): Column = withExpr {
+ new MaskShowLastN(e.expr, Literal(n))
+ }
+
+ /**
+ * Returns a string with all but the last `n` characters masked, using `upper`, `lower` and
+ * `digit` as replacement characters.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask_show_last_n(
+ e: Column,
+ n: Int,
+ upper: String,
+ lower: String,
+ digit: String): Column = withExpr {
+ MaskShowLastN(e.expr, n, upper, lower, digit)
+ }
+
+ /**
+ * Returns a hashed value based on the input column.
+ * @group mask_funcs
+ * @since 2.4.0
+ */
+ def mask_hash(e: Column): Column = withExpr { MaskHash(e.expr) }
+
// scalastyle:off line.size.limit
// scalastyle:off parameter.number
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 007f8760edf82..3a0db7e16c23a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -130,8 +130,8 @@ abstract class BaseSessionStateBuilder(
*/
protected lazy val catalog: SessionCatalog = {
val catalog = new SessionCatalog(
- session.sharedState.externalCatalog,
- session.sharedState.globalTempViewManager,
+ () => session.sharedState.externalCatalog,
+ () => session.sharedState.globalTempViewManager,
functionRegistry,
conf,
SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
index dac463641cfab..eca612f06f9bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala
@@ -31,7 +31,8 @@ object HiveSerDe {
"sequencefile" ->
HiveSerDe(
inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"),
- outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")),
+ outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat"),
+ serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")),
"rcfile" ->
HiveSerDe(
@@ -54,7 +55,8 @@ object HiveSerDe {
"textfile" ->
HiveSerDe(
inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"),
- outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")),
+ outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"),
+ serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")),
"avro" ->
HiveSerDe(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index baea4ceebf8e3..5b6160e2b408f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -99,7 +99,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging {
/**
* A catalog that interacts with external systems.
*/
- lazy val externalCatalog: ExternalCatalog = {
+ lazy val externalCatalog: ExternalCatalogWithListener = {
val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration](
SharedState.externalCatalogClassName(sparkContext.conf),
sparkContext.conf,
@@ -117,14 +117,17 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging {
externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true)
}
+ // Wrap to provide catalog events
+ val wrapped = new ExternalCatalogWithListener(externalCatalog)
+
// Make sure we propagate external catalog events to the spark listener bus
- externalCatalog.addListener(new ExternalCatalogEventListener {
+ wrapped.addListener(new ExternalCatalogEventListener {
override def onEvent(event: ExternalCatalogEvent): Unit = {
sparkContext.listenerBus.post(event)
}
})
- externalCatalog
+ wrapped
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 116ac3da07b75..ef8dc3a325a33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
import org.apache.spark.sql.sources.StreamSourceProvider
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -173,15 +173,25 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
}
ds match {
case s: MicroBatchReadSupport =>
- val tempReader = s.createMicroBatchReader(
- Optional.ofNullable(userSpecifiedSchema.orNull),
- Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath,
- options)
+ var tempReader: MicroBatchReader = null
+ val schema = try {
+ tempReader = s.createMicroBatchReader(
+ Optional.ofNullable(userSpecifiedSchema.orNull),
+ Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath,
+ options)
+ tempReader.readSchema()
+ } finally {
+ // Stop tempReader to avoid any side effects
+ if (tempReader != null) {
+ tempReader.stop()
+ tempReader = null
+ }
+ }
Dataset.ofRows(
sparkSession,
StreamingRelationV2(
s, source, extraOptions.toMap,
- tempReader.readSchema().toAttributes, v1Relation)(sparkSession))
+ schema.toAttributes, v1Relation)(sparkSession))
case s: ContinuousReadSupport =>
val tempReader = s.createContinuousReader(
Optional.ofNullable(userSpecifiedSchema.orNull),
@@ -237,12 +247,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.
*
- * `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
- * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
- * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
- * in an user-defined schema. If a schema does not have the field, it drops corrupt records
- * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord`
- * field in an output schema.
+ * `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a
+ * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To
+ * keep corrupt records, a user can set a string type field named
+ * `columnNameOfCorruptRecord` in a user-defined schema. If a schema does not have the
+ * field, it drops corrupt records during parsing. When inferring a schema, it implicitly
+ * adds a `columnNameOfCorruptRecord` field in an output schema.
* `DROPMALFORMED` : ignores the whole corrupted records.
* `FAILFAST` : throws an exception when it meets corrupted records.
*
@@ -258,6 +268,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `java.text.SimpleDateFormat`. This applies to timestamp type.
* `multiLine` (default `false`): parse one record, which may span multiple lines,
* per file
+ * `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
+ * `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
+ * empty array/struct during schema inference.
*
*
* @since 2.0.0
@@ -317,12 +331,14 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing. It supports the following case-insensitive modes.
*
- * `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
- * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep
+ * `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a
+ * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep
* corrupt records, a user can set a string type field named `columnNameOfCorruptRecord`
* in a user-defined schema. If a schema does not have the field, it drops corrupt records
- * during parsing. When a length of parsed CSV tokens is shorter than an expected length
- * of a schema, it sets `null` for extra fields.
+ * during parsing. A record with fewer or more tokens than the schema is not a corrupted
+ * record to CSV. When it meets a record having fewer tokens than the length of the schema,
+ * it sets `null` to the extra fields. When the record has more tokens than the length of
+ * the schema, it drops the extra tokens.
* `DROPMALFORMED` : ignores the whole corrupted records.
* `FAILFAST` : throws an exception when it meets corrupted records.
*
@@ -375,7 +391,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* Loads text files and returns a `DataFrame` whose schema starts with a string column named
* "value", and followed by partitioned columns if there are any.
*
- * Each line in the text files is a new row in the resulting DataFrame. For example:
+ * By default, each line in the text files is a new row in the resulting DataFrame. For example:
* {{{
* // Scala:
* spark.readStream.text("/path/to/directory/")
@@ -388,6 +404,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
*
* `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
* considered in every trigger.
+ * `wholetext` (default `false`): If true, read a file as a single row and not split by "\n".
+ * `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
*
*
* @since 2.0.0
@@ -401,7 +421,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* If the directory structure of the text files contains partitioning information, those are
* ignored in the resulting Dataset. To include partitioning information as columns, use `text`.
*
- * Each line in the text file is a new element in the resulting Dataset. For example:
+ * By default, each line in the text file is a new element in the resulting Dataset. For example:
* {{{
* // Scala:
* spark.readStream.textFile("/path/to/spark/README.md")
@@ -414,6 +434,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
*
* `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
* considered in every trigger.
+ * `wholetext` (default `false`): If true, read a file as a single row and not split by "\n".
+ * `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing.
*
*
* @param path input path
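The option docs added above apply to the streaming text sources. A minimal sketch of using them together, with a placeholder path (not part of the patch):

  // Sketch: read a directory of text files as a stream, one row per line,
  // with an explicit line separator and a cap on files per trigger.
  val lines = spark.readStream
    .option("wholetext", "false")
    .option("lineSep", "\n")
    .option("maxFilesPerTrigger", "1")
    .text("/path/to/directory/")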
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 9aac360fd4bbc..43e80e4e54239 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
-import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
-import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport
+import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -249,7 +249,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
(s, r)
case _ =>
- val s = new MemorySink(df.schema, outputMode)
+ val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava))
val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
(s, r)
}
@@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
query
} else if (source == "foreach") {
assertNotPartitioned("foreach")
- val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc)
+ val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
@@ -307,49 +307,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
}
/**
- * Starts the execution of the streaming query, which will continually send results to the given
- * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data
- * generated by the `DataFrame`/`Dataset` to an external system.
- *
- * Scala example:
- * {{{
- * datasetOfString.writeStream.foreach(new ForeachWriter[String] {
- *
- * def open(partitionId: Long, version: Long): Boolean = {
- * // open connection
- * }
- *
- * def process(record: String) = {
- * // write string to connection
- * }
- *
- * def close(errorOrNull: Throwable): Unit = {
- * // close the connection
- * }
- * }).start()
- * }}}
- *
- * Java example:
- * {{{
- * datasetOfString.writeStream().foreach(new ForeachWriter() {
- *
- * @Override
- * public boolean open(long partitionId, long version) {
- * // open connection
- * }
- *
- * @Override
- * public void process(String value) {
- * // write string to connection
- * }
- *
- * @Override
- * public void close(Throwable errorOrNull) {
- * // close the connection
- * }
- * }).start();
- * }}}
- *
+ * Sets the output of the streaming query to be processed using the provided writer object.
+ * See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and
+ * semantics.
* @since 2.0.0
*/
def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
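The long inline examples are removed in favour of the ForeachWriter scaladoc; for reference, a minimal sketch of the Scala usage they showed, with datasetOfString standing in for any Dataset[String] (not part of the patch):

  datasetOfString.writeStream.foreach(new ForeachWriter[String] {
    def open(partitionId: Long, version: Long): Boolean = true   // open the connection
    def process(record: String): Unit = { /* write the record to the connection */ }
    def close(errorOrNull: Throwable): Unit = { /* close the connection */ }
  }).start()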
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index ddb1edc433d5a..25bb05212d66f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkException
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
@@ -32,7 +33,8 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger}
import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport
+import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS
+import org.apache.spark.sql.sources.v2.StreamWriteSupport
import org.apache.spark.util.{Clock, SystemClock, Utils}
/**
@@ -55,6 +57,19 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
@GuardedBy("awaitTerminationLock")
private var lastTerminatedQuery: StreamingQuery = null
+ try {
+ sparkSession.sparkContext.conf.get(STREAMING_QUERY_LISTENERS).foreach { classNames =>
+ Utils.loadExtensions(classOf[StreamingQueryListener], classNames,
+ sparkSession.sparkContext.conf).foreach(listener => {
+ addListener(listener)
+ logInfo(s"Registered listener ${listener.getClass.getName}")
+ })
+ }
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Exception when registering StreamingQueryListener", e)
+ }
+
/**
* Returns a list of active queries associated with this SQLContext
*
@@ -242,7 +257,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
(sink, trigger) match {
case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) =>
- UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
+ if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
+ UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
+ }
new StreamingQueryWrapper(new ContinuousExecution(
sparkSession,
userSpecifiedName.orNull,
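The new block earlier in this file loads StreamingQueryListener implementations named in the STREAMING_QUERY_LISTENERS static conf when the manager is constructed. A sketch of registering a listener this way; the key is assumed to be what STREAMING_QUERY_LISTENERS resolves to, and com.example.MyQueryListener is a hypothetical class with a no-arg constructor (not part of the patch):

  val spark = SparkSession.builder()
    .config("spark.sql.streaming.streamingQueryListeners", "com.example.MyQueryListener")
    .getOrCreate()
  // The listener is instantiated via Utils.loadExtensions and added before any query starts.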
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index c132cab1b38cf..2c695fc58fd8c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -34,6 +34,7 @@
import org.junit.*;
import org.junit.rules.ExpectedException;
+import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.*;
import org.apache.spark.sql.*;
@@ -336,6 +337,23 @@ public void testTupleEncoder() {
Assert.assertEquals(data5, ds5.collectAsList());
}
+ @Test
+ public void testTupleEncoderSchema() {
+ Encoder<Tuple2<String, Tuple2<String, String>>> encoder =
+ Encoders.tuple(Encoders.STRING(), Encoders.tuple(Encoders.STRING(), Encoders.STRING()));
+ List<Tuple2<String, Tuple2<String, String>>> data = Arrays.asList(tuple2("1", tuple2("a", "b")),
+ tuple2("2", tuple2("c", "d")));
+ Dataset<Row> ds1 = spark.createDataset(data, encoder).toDF("value1", "value2");
+
+ JavaPairRDD<String, Tuple2<String, String>> pairRDD = jsc.parallelizePairs(data);
+ Dataset<Row> ds2 = spark.createDataset(JavaPairRDD.toRDD(pairRDD), encoder)
+ .toDF("value1", "value2");
+
+ Assert.assertEquals(ds1.schema(), ds2.schema());
+ Assert.assertEquals(ds1.select(expr("value2._1")).collectAsList(),
+ ds2.select(expr("value2._1")).collectAsList());
+ }
+
@Test
public void testNestedTupleEncoder() {
// test ((int, string), string)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java
index ddbaa45a483cb..08dc129f27a0c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java
@@ -46,7 +46,7 @@ public void tearDown() {
@SuppressWarnings("unchecked")
@Test
public void udf1Test() {
- spark.range(1, 10).toDF("value").registerTempTable("df");
+ spark.range(1, 10).toDF("value").createOrReplaceTempView("df");
spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName());
Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head();
Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6);
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
index 172e5d5eebcbe..445cb29f5ee3a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
@@ -79,8 +79,8 @@ public Filter[] pushedFilters() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
- List<DataReaderFactory<Row>> res = new ArrayList<>();
+ public List<InputPartition<Row>> planInputPartitions() {
+ List<InputPartition<Row>> res = new ArrayList<>();
Integer lowerBound = null;
for (Filter filter : filters) {
@@ -94,33 +94,34 @@ public List<DataReaderFactory<Row>> createDataReaderFactories() {
}
if (lowerBound == null) {
- res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema));
- res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema));
+ res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema));
+ res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema));
} else if (lowerBound < 4) {
- res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema));
- res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema));
+ res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema));
+ res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema));
} else if (lowerBound < 9) {
- res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema));
+ res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema));
}
return res;
}
}
- static class JavaAdvancedDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+ static class JavaAdvancedInputPartition implements InputPartition<Row>,
+ InputPartitionReader<Row> {
private int start;
private int end;
private StructType requiredSchema;
- JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) {
+ JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) {
this.start = start;
this.end = end;
this.requiredSchema = requiredSchema;
}
@Override
- public DataReader<Row> createDataReader() {
- return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema);
+ public InputPartitionReader<Row> createPartitionReader() {
+ return new JavaAdvancedInputPartition(start - 1, end, requiredSchema);
}
@Override
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
index c55093768105b..97d6176d02559 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
@@ -42,14 +42,14 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<ColumnarBatch>> createBatchDataReaderFactories() {
+ public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
return java.util.Arrays.asList(
- new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90));
+ new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90));
}
}
- static class JavaBatchDataReaderFactory
- implements DataReaderFactory<ColumnarBatch>, DataReader<ColumnarBatch> {
+ static class JavaBatchInputPartition
+ implements InputPartition<ColumnarBatch>, InputPartitionReader<ColumnarBatch> {
private int start;
private int end;
@@ -59,13 +59,13 @@ static class JavaBatchDataReaderFactory
private OnHeapColumnVector j;
private ColumnarBatch batch;
- JavaBatchDataReaderFactory(int start, int end) {
+ JavaBatchInputPartition(int start, int end) {
this.start = start;
this.end = end;
}
@Override
- public DataReader<ColumnarBatch> createDataReader() {
+ public InputPartitionReader<ColumnarBatch> createPartitionReader() {
this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType);
this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType);
ColumnVector[] vectors = new ColumnVector[2];
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
index 32fad59b97ff6..e49c8cf8b9e16 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
@@ -43,10 +43,10 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
+ public List<InputPartition<Row>> planInputPartitions() {
return java.util.Arrays.asList(
- new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
- new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
+ new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
+ new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
}
@Override
@@ -73,12 +73,12 @@ public boolean satisfy(Distribution distribution) {
}
}
- static class SpecificDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+ static class SpecificInputPartition implements InputPartition<Row>, InputPartitionReader<Row> {
private int[] i;
private int[] j;
private int current = -1;
- SpecificDataReaderFactory(int[] i, int[] j) {
+ SpecificInputPartition(int[] i, int[] j) {
assert i.length == j.length;
this.i = i;
this.j = j;
@@ -101,7 +101,7 @@ public void close() throws IOException {
}
@Override
- public DataReader<Row> createDataReader() {
+ public InputPartitionReader<Row> createPartitionReader() {
return this;
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
index 048d078dfaac4..80eeffd95f83b 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
@@ -24,7 +24,7 @@
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.types.StructType;
public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema {
@@ -42,7 +42,7 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
+ public List<InputPartition<Row>> planInputPartitions() {
return java.util.Collections.emptyList();
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
index 96f55b8a76811..8522a63898a3b 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
@@ -25,8 +25,8 @@
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.ReadSupport;
-import org.apache.spark.sql.sources.v2.reader.DataReader;
-import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.types.StructType;
@@ -41,25 +41,25 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<Row>> createDataReaderFactories() {
+ public List<InputPartition<Row>> planInputPartitions() {
return java.util.Arrays.asList(
- new JavaSimpleDataReaderFactory(0, 5),
- new JavaSimpleDataReaderFactory(5, 10));
+ new JavaSimpleInputPartition(0, 5),
+ new JavaSimpleInputPartition(5, 10));
}
}
- static class JavaSimpleDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
+ static class JavaSimpleInputPartition implements InputPartition<Row>, InputPartitionReader<Row> {
private int start;
private int end;
- JavaSimpleDataReaderFactory(int start, int end) {
+ JavaSimpleInputPartition(int start, int end) {
this.start = start;
this.end = end;
}
@Override
- public DataReader<Row> createDataReader() {
- return new JavaSimpleDataReaderFactory(start - 1, end);
+ public InputPartitionReader<Row> createPartitionReader() {
+ return new JavaSimpleInputPartition(start - 1, end);
}
@Override
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
index c3916e0b370b5..3ad8e7a0104ce 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
@@ -38,20 +38,20 @@ public StructType readSchema() {
}
@Override
- public List<DataReaderFactory<UnsafeRow>> createUnsafeRowReaderFactories() {
+ public List<InputPartition<UnsafeRow>> planUnsafeInputPartitions() {
return java.util.Arrays.asList(
- new JavaUnsafeRowDataReaderFactory(0, 5),
- new JavaUnsafeRowDataReaderFactory(5, 10));
+ new JavaUnsafeRowInputPartition(0, 5),
+ new JavaUnsafeRowInputPartition(5, 10));
}
}
- static class JavaUnsafeRowDataReaderFactory
- implements DataReaderFactory<UnsafeRow>, DataReader<UnsafeRow> {
+ static class JavaUnsafeRowInputPartition
+ implements InputPartition<UnsafeRow>, InputPartitionReader<UnsafeRow> {
private int start;
private int end;
private UnsafeRow row;
- JavaUnsafeRowDataReaderFactory(int start, int end) {
+ JavaUnsafeRowInputPartition(int start, int end) {
this.start = start;
this.end = end;
this.row = new UnsafeRow(2);
@@ -59,8 +59,8 @@ static class JavaUnsafeRowDataReaderFactory
}
@Override
- public DataReader<UnsafeRow> createDataReader() {
- return new JavaUnsafeRowDataReaderFactory(start - 1, end);
+ public InputPartitionReader<UnsafeRow> createPartitionReader() {
+ return new JavaUnsafeRowInputPartition(start - 1, end);
}
@Override
diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql
index ad0f885f63d3d..2909024e4c9f7 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql
@@ -49,6 +49,7 @@ ALTER TABLE global_temp.global_temp_view CHANGE a a INT COMMENT 'this is column
-- Change column in partition spec (not supported yet)
CREATE TABLE partition_table(a INT, b STRING, c INT, d STRING) USING parquet PARTITIONED BY (c, d);
ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT;
+ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C';
-- DROP TEST TABLE
DROP TABLE test_change;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
index adea2bfa82cd3..4950a4b7a4e5a 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql
@@ -25,3 +25,38 @@ create temporary view ttf2 as select * from values
select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2;
select a, b from ttf2 order by a, current_date;
+
+select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15');
+
+select from_utc_timestamp('2015-07-24 00:00:00', 'PST');
+
+select from_utc_timestamp('2015-01-24 00:00:00', 'PST');
+
+select from_utc_timestamp(null, 'PST');
+
+select from_utc_timestamp('2015-07-24 00:00:00', null);
+
+select from_utc_timestamp(null, null);
+
+select from_utc_timestamp(cast(0 as timestamp), 'PST');
+
+select from_utc_timestamp(cast('2015-01-24' as date), 'PST');
+
+select to_utc_timestamp('2015-07-24 00:00:00', 'PST');
+
+select to_utc_timestamp('2015-01-24 00:00:00', 'PST');
+
+select to_utc_timestamp(null, 'PST');
+
+select to_utc_timestamp('2015-07-24 00:00:00', null);
+
+select to_utc_timestamp(null, null);
+
+select to_utc_timestamp(cast(0 as timestamp), 'PST');
+
+select to_utc_timestamp(cast('2015-01-24' as date), 'PST');
+
+-- SPARK-23715: the input of to/from_utc_timestamp can not have timezone
+select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST');
+
+select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
index 72f05f49f1619..35f2be46cd130 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
@@ -40,12 +40,14 @@ select 10.3000 * 3.0;
select 10.30000 * 30.0;
select 10.300000000000000000 * 3.000000000000000000;
select 10.300000000000000000 * 3.0000000000000000000;
+select 2.35E10 * 1.0;
-- arithmetic operations causing an overflow return NULL
select (5e36 + 0.1) + 5e36;
select (-4e36 - 0.1) - 7e36;
select 12345678901234567890.0 * 12345678901234567890.0;
select 1e35 / 0.1;
+select 1.2345678901234567890E30 * 1.2345678901234567890E25;
-- arithmetic operations causing a precision loss are truncated
select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345;
@@ -67,12 +69,14 @@ select 10.3000 * 3.0;
select 10.30000 * 30.0;
select 10.300000000000000000 * 3.000000000000000000;
select 10.300000000000000000 * 3.0000000000000000000;
+select 2.35E10 * 1.0;
-- arithmetic operations causing an overflow return NULL
select (5e36 + 0.1) + 5e36;
select (-4e36 - 0.1) - 7e36;
select 12345678901234567890.0 * 12345678901234567890.0;
select 1e35 / 0.1;
+select 1.2345678901234567890E30 * 1.2345678901234567890E25;
-- arithmetic operations causing a precision loss return NULL
select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345;
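The "truncated" and "return NULL" halves of this test file run the same queries; the only difference is the allowPrecisionLoss flag that the file sets in between. A sketch of toggling it programmatically (not part of the test file):

  // With precision loss disallowed, results whose exact value does not fit the result type
  // become NULL instead of being truncated.
  spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
  spark.sql(
    "select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345")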
diff --git a/sql/core/src/test/resources/sql-tests/inputs/extract.sql b/sql/core/src/test/resources/sql-tests/inputs/extract.sql
new file mode 100644
index 0000000000000..9adf5d70056e2
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/extract.sql
@@ -0,0 +1,21 @@
+CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c;
+
+select extract(year from c) from t;
+
+select extract(quarter from c) from t;
+
+select extract(month from c) from t;
+
+select extract(week from c) from t;
+
+select extract(day from c) from t;
+
+select extract(dayofweek from c) from t;
+
+select extract(hour from c) from t;
+
+select extract(minute from c) from t;
+
+select extract(second from c) from t;
+
+select extract(not_supported from c) from t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index c5070b734d521..2c18d6aaabdba 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -68,4 +68,8 @@ SELECT 1 from (
FROM (select 1 as x) a
WHERE false
) b
-where b.z != b.z
+where b.z != b.z;
+
+-- SPARK-24369 multiple distinct aggregations having the same argument set
+SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
+ FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
index fea069eac4d48..dc15d13cd1dd3 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
@@ -31,3 +31,7 @@ CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1,
SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable;
-- Clean up
DROP VIEW IF EXISTS jsonTable;
+
+-- from_json - complex types
+select from_json('{"a":1, "b":2}', 'map');
+select from_json('{"a":1, "b":"2"}', 'struct');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql
index 37b4b7606d12b..a743cf1ec2cde 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql
@@ -105,3 +105,6 @@ select X'XuZ';
-- Hive literal_double test.
SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8;
+
+-- map + interval test
+select map(1, interval 1 day, 2, interval 3 week);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
new file mode 100644
index 0000000000000..01dea6c81c11b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
@@ -0,0 +1,113 @@
+create temporary view courseSales as select * from values
+ ("dotNET", 2012, 10000),
+ ("Java", 2012, 20000),
+ ("dotNET", 2012, 5000),
+ ("dotNET", 2013, 48000),
+ ("Java", 2013, 30000)
+ as courseSales(course, year, earnings);
+
+create temporary view years as select * from values
+ (2012, 1),
+ (2013, 2)
+ as years(y, s);
+
+-- pivot courses
+SELECT * FROM (
+ SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+ sum(earnings)
+ FOR course IN ('dotNET', 'Java')
+);
+
+-- pivot years with no subquery
+SELECT * FROM courseSales
+PIVOT (
+ sum(earnings)
+ FOR year IN (2012, 2013)
+);
+
+-- pivot courses with multiple aggregations
+SELECT * FROM (
+ SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+ sum(earnings), avg(earnings)
+ FOR course IN ('dotNET', 'Java')
+);
+
+-- pivot with no group by column
+SELECT * FROM (
+ SELECT course, earnings FROM courseSales
+)
+PIVOT (
+ sum(earnings)
+ FOR course IN ('dotNET', 'Java')
+);
+
+-- pivot with no group by column and with multiple aggregations on different columns
+SELECT * FROM (
+ SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+ sum(earnings), min(year)
+ FOR course IN ('dotNET', 'Java')
+);
+
+-- pivot on join query with multiple group by columns
+SELECT * FROM (
+ SELECT course, year, earnings, s
+ FROM courseSales
+ JOIN years ON year = y
+)
+PIVOT (
+ sum(earnings)
+ FOR s IN (1, 2)
+);
+
+-- pivot on join query with multiple aggregations on different columns
+SELECT * FROM (
+ SELECT course, year, earnings, s
+ FROM courseSales
+ JOIN years ON year = y
+)
+PIVOT (
+ sum(earnings), min(s)
+ FOR course IN ('dotNET', 'Java')
+);
+
+-- pivot on join query with multiple columns in one aggregation
+SELECT * FROM (
+ SELECT course, year, earnings, s
+ FROM courseSales
+ JOIN years ON year = y
+)
+PIVOT (
+ sum(earnings * s)
+ FOR course IN ('dotNET', 'Java')
+);
+
+-- pivot with aliases and projection
+SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM (
+ SELECT year y, course c, earnings e FROM courseSales
+)
+PIVOT (
+ sum(e) s, avg(e) a
+ FOR y IN (2012, 2013)
+);
+
+-- pivot years with non-aggregate function
+SELECT * FROM courseSales
+PIVOT (
+ abs(earnings)
+ FOR year IN (2012, 2013)
+);
+
+-- pivot with unresolvable columns
+SELECT * FROM (
+ SELECT course, earnings FROM courseSales
+)
+PIVOT (
+ sum(earnings)
+ FOR year IN (2012, 2013)
+);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql
index e99d5cef81f64..fadb4bb27fa13 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql
@@ -39,3 +39,10 @@ select 2.0 <= '2.2';
select 0.5 <= '1.5';
select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52');
select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52';
+
+-- SPARK-23549: Cast to timestamp when comparing timestamp with date
+select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00');
+select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01');
+select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01');
+select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01');
+select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql
new file mode 100644
index 0000000000000..8eea84f4f5272
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql
@@ -0,0 +1,39 @@
+-- Unit tests for simple NOT IN predicate subquery across multiple columns.
+--
+-- See not-in-single-column-unit-tests.sql for an introduction.
+-- This file has the same test cases as not-in-unit-tests-multi-column.sql with literals instead of
+-- subqueries. Small changes have been made to the literals to make them typecheck.
+
+CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES
+ (null, null),
+ (null, 1.0),
+ (2, 3.0),
+ (4, 5.0)
+ AS m(a, b);
+
+-- Case 1 (not possible to write a literal with no rows, so we ignore it.)
+-- (subquery is empty -> row is returned)
+
+-- Cases 2, 3 and 4 are currently broken, so I have commented them out here.
+-- Filed https://issues.apache.org/jira/browse/SPARK-24395 to fix and restore these test cases.
+
+ -- Case 5
+ -- (one null column with no match -> row is returned)
+SELECT *
+FROM m
+WHERE b = 1.0 -- Matches (null, 1.0)
+ AND (a, b) NOT IN ((2, 3.0));
+
+ -- Case 6
+ -- (no null columns with match -> row not returned)
+SELECT *
+FROM m
+WHERE b = 3.0 -- Matches (2, 3.0)
+ AND (a, b) NOT IN ((2, 3.0));
+
+ -- Case 7
+ -- (no null columns with no match -> row is returned)
+SELECT *
+FROM m
+WHERE b = 5.0 -- Matches (4, 5.0)
+ AND (a, b) NOT IN ((2, 3.0));
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql
new file mode 100644
index 0000000000000..9f8dc7fca3b94
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql
@@ -0,0 +1,98 @@
+-- Unit tests for simple NOT IN predicate subquery across multiple columns.
+--
+-- See not-in-single-column-unit-tests.sql for an introduction.
+--
+-- Test cases for multi-column ``WHERE a NOT IN (SELECT c FROM r ...)'':
+-- | # | does subquery include null? | do filter columns contain null? | a = c? | b = d? | row included in result? |
+-- | 1 | empty | * | * | * | yes |
+-- | 2 | 1+ row has null for all columns | * | * | * | no |
+-- | 3 | no row has null for all columns | (yes, yes) | * | * | no |
+-- | 4 | no row has null for all columns | (no, yes) | yes | * | no |
+-- | 5 | no row has null for all columns | (no, yes) | no | * | yes |
+-- | 6 | no | (no, no) | yes | yes | no |
+-- | 7 | no | (no, no) | _ | _ | yes |
+--
+-- This can be generalized to include more tests for more columns, but it covers the main cases
+-- when there is more than one column.
+
+CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES
+ (null, null),
+ (null, 1.0),
+ (2, 3.0),
+ (4, 5.0)
+ AS m(a, b);
+
+CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES
+ (null, null),
+ (0, 1.0),
+ (2, 3.0),
+ (4, null)
+ AS s(c, d);
+
+ -- Case 1
+ -- (subquery is empty -> row is returned)
+SELECT *
+FROM m
+WHERE (a, b) NOT IN (SELECT *
+ FROM s
+ WHERE d > 5.0) -- Matches no rows
+;
+
+ -- Case 2
+ -- (subquery contains a row with null in all columns -> row not returned)
+SELECT *
+FROM m
+WHERE (a, b) NOT IN (SELECT *
+ FROM s
+ WHERE c IS NULL AND d IS NULL) -- Matches only (null, null)
+;
+
+ -- Case 3
+ -- (probe-side columns are all null -> row not returned)
+SELECT *
+FROM m
+WHERE a IS NULL AND b IS NULL -- Matches only (null, null)
+ AND (a, b) NOT IN (SELECT *
+ FROM s
+ WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null)
+;
+
+ -- Case 4
+ -- (one column null, other column matches a row in the subquery result -> row not returned)
+SELECT *
+FROM m
+WHERE b = 1.0 -- Matches (null, 1.0)
+ AND (a, b) NOT IN (SELECT *
+ FROM s
+ WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null)
+;
+
+ -- Case 5
+ -- (one null column with no match -> row is returned)
+SELECT *
+FROM m
+WHERE b = 1.0 -- Matches (null, 1.0)
+ AND (a, b) NOT IN (SELECT *
+ FROM s
+ WHERE c = 2) -- Matches (2, 3.0)
+;
+
+ -- Case 6
+ -- (no null columns with match -> row not returned)
+SELECT *
+FROM m
+WHERE b = 3.0 -- Matches (2, 3.0)
+ AND (a, b) NOT IN (SELECT *
+ FROM s
+ WHERE c = 2) -- Matches (2, 3.0)
+;
+
+ -- Case 7
+ -- (no null columns with no match -> row is returned)
+SELECT *
+FROM m
+WHERE b = 5.0 -- Matches (4, 5.0)
+ AND (a, b) NOT IN (SELECT *
+ FROM s
+ WHERE c = 2) -- Matches (2, 3.0)
+;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql
new file mode 100644
index 0000000000000..b261363d1dde7
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql
@@ -0,0 +1,42 @@
+-- Unit tests for simple NOT IN with a literal expression of a single column
+--
+-- More information can be found in not-in-unit-tests-single-column.sql.
+-- This file has the same test cases as not-in-unit-tests-single-column.sql with literals instead of
+-- subqueries.
+
+CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES
+ (null, 1.0),
+ (2, 3.0),
+ (4, 5.0)
+ AS m(a, b);
+
+ -- Uncorrelated NOT IN Subquery test cases
+ -- Case 1 (not possible to write a literal with no rows, so we ignore it.)
+ -- (empty subquery -> all rows returned)
+
+ -- Case 2
+ -- (subquery includes null -> no rows returned)
+SELECT *
+FROM m
+WHERE a NOT IN (null);
+
+ -- Case 3
+ -- (probe column is null -> row not returned)
+SELECT *
+FROM m
+WHERE b = 1.0 -- Only matches (null, 1.0)
+ AND a NOT IN (2);
+
+ -- Case 4
+ -- (probe column matches subquery row -> row not returned)
+SELECT *
+FROM m
+WHERE b = 3.0 -- Only matches (2, 3.0)
+ AND a NOT IN (2);
+
+ -- Case 5
+ -- (probe column does not match subquery row -> row is returned)
+SELECT *
+FROM m
+WHERE b = 3.0 -- Only matches (2, 3.0)
+ AND a NOT IN (6);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql
new file mode 100644
index 0000000000000..2cc08e10acf67
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql
@@ -0,0 +1,123 @@
+-- Unit tests for simple NOT IN predicate subquery across a single column.
+--
+-- ``col NOT IN expr'' is quite difficult to reason about. There are many edge cases, some of the
+-- rules are confusing to the uninitiated, and precedence and treatment of null values is plain
+-- unintuitive. To make this simpler to understand, I've come up with a plain English way of
+-- describing the expected behavior of this query.
+--
+-- - If the subquery is empty (i.e. returns no rows), the row should be returned, regardless of
+-- whether the filtered columns include nulls.
+-- - If the subquery contains a result with all columns null, then the row should not be returned.
+-- - If for all non-null filter columns there exists a row in the subquery in which each column
+-- either
+-- 1. is equal to the corresponding filter column or
+-- 2. is null
+-- then the row should not be returned. (This includes the case where all filter columns are
+-- null.)
+-- - Otherwise, the row should be returned.
+--
+-- Using these rules, we can come up with a set of test cases for single-column and multi-column
+-- NOT IN test cases.
+--
+-- Test cases for single-column ``WHERE a NOT IN (SELECT c FROM r ...)'':
+-- | # | does subquery include null? | is a null? | a = c? | row with a included in result? |
+-- | 1 | empty | | | yes |
+-- | 2 | yes | | | no |
+-- | 3 | no | yes | | no |
+-- | 4 | no | no | yes | no |
+-- | 5 | no | no | no | yes |
+--
+-- There are also some considerations around correlated subqueries. Correlated subqueries can
+-- cause cases 2, 3, or 4 to be reduced to case 1 by limiting the number of rows returned by the
+-- subquery, so the row from the parent table should always be included in the output.
+
+CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES
+ (null, 1.0),
+ (2, 3.0),
+ (4, 5.0)
+ AS m(a, b);
+
+CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES
+ (null, 1.0),
+ (2, 3.0),
+ (6, 7.0)
+ AS s(c, d);
+
+ -- Uncorrelated NOT IN Subquery test cases
+ -- Case 1
+ -- (empty subquery -> all rows returned)
+SELECT *
+FROM m
+WHERE a NOT IN (SELECT c
+ FROM s
+ WHERE d > 10.0) -- (empty subquery)
+;
+
+ -- Case 2
+ -- (subquery includes null -> no rows returned)
+SELECT *
+FROM m
+WHERE a NOT IN (SELECT c
+ FROM s
+ WHERE d = 1.0) -- Only matches (null, 1.0)
+;
+
+ -- Case 3
+ -- (probe column is null -> row not returned)
+SELECT *
+FROM m
+WHERE b = 1.0 -- Only matches (null, 1.0)
+ AND a NOT IN (SELECT c
+ FROM s
+ WHERE d = 3.0) -- Matches (2, 3.0)
+;
+
+ -- Case 4
+ -- (probe column matches subquery row -> row not returned)
+SELECT *
+FROM m
+WHERE b = 3.0 -- Only matches (2, 3.0)
+ AND a NOT IN (SELECT c
+ FROM s
+ WHERE d = 3.0) -- Matches (2, 3.0)
+;
+
+ -- Case 5
+ -- (probe column does not match subquery row -> row is returned)
+SELECT *
+FROM m
+WHERE b = 3.0 -- Only matches (2, 3.0)
+ AND a NOT IN (SELECT c
+ FROM s
+ WHERE d = 7.0) -- Matches (6, 7.0)
+;
+
+ -- Correlated NOT IN subquery test cases
+ -- Case 2->1
+ -- (subquery had nulls but they are removed by correlated subquery -> all rows returned)
+SELECT *
+FROM m
+WHERE a NOT IN (SELECT c
+ FROM s
+ WHERE d = b + 10) -- Matches no row
+;
+
+ -- Case 3->1
+ -- (probe column is null but subquery returns no rows -> row is returned)
+SELECT *
+FROM m
+WHERE b = 1.0 -- Only matches (null, 1.0)
+ AND a NOT IN (SELECT c
+ FROM s
+ WHERE d = b + 10) -- Matches no row
+;
+
+ -- Case 4->1
+ -- (probe column matches row which is filtered out by correlated subquery -> row is returned)
+SELECT *
+FROM m
+WHERE b = 3.0 -- Only matches (2, 3.0)
+ AND a NOT IN (SELECT c
+ FROM s
+ WHERE d = b + 10) -- Matches no row
+;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
index 0beebec5702fd..db00a18f2e7e9 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
@@ -91,3 +91,65 @@ FROM (
encode(string(id + 3), 'utf-8') col4
FROM range(10)
);
+
+CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES (
+ array(true, false), array(true),
+ array(2Y, 1Y), array(3Y, 4Y),
+ array(2S, 1S), array(3S, 4S),
+ array(2, 1), array(3, 4),
+ array(2L, 1L), array(3L, 4L),
+ array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809),
+ array(2.0D, 1.0D), array(3.0D, 4.0D),
+ array(float(2.0), float(1.0)), array(float(3.0), float(4.0)),
+ array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'),
+ array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+ array(timestamp '2016-11-11 20:54:00.000'),
+ array('a', 'b'), array('c', 'd'),
+ array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')),
+ array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)),
+ array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4))
+) AS various_arrays(
+ boolean_array1, boolean_array2,
+ tinyint_array1, tinyint_array2,
+ smallint_array1, smallint_array2,
+ int_array1, int_array2,
+ bigint_array1, bigint_array2,
+ decimal_array1, decimal_array2,
+ double_array1, double_array2,
+ float_array1, float_array2,
+ date_array1, data_array2,
+ timestamp_array1, timestamp_array2,
+ string_array1, string_array2,
+ array_array1, array_array2,
+ struct_array1, struct_array2,
+ map_array1, map_array2
+);
+
+-- Concatenate arrays of the same type
+SELECT
+ (boolean_array1 || boolean_array2) boolean_array,
+ (tinyint_array1 || tinyint_array2) tinyint_array,
+ (smallint_array1 || smallint_array2) smallint_array,
+ (int_array1 || int_array2) int_array,
+ (bigint_array1 || bigint_array2) bigint_array,
+ (decimal_array1 || decimal_array2) decimal_array,
+ (double_array1 || double_array2) double_array,
+ (float_array1 || float_array2) float_array,
+ (date_array1 || data_array2) data_array,
+ (timestamp_array1 || timestamp_array2) timestamp_array,
+ (string_array1 || string_array2) string_array,
+ (array_array1 || array_array2) array_array,
+ (struct_array1 || struct_array2) struct_array,
+ (map_array1 || map_array2) map_array
+FROM various_arrays;
+
+-- Concatenate arrays of different types
+SELECT
+ (tinyint_array1 || smallint_array2) ts_array,
+ (smallint_array1 || int_array2) si_array,
+ (int_array1 || bigint_array2) ib_array,
+ (double_array1 || float_array2) df_array,
+ (string_array1 || data_array2) std_array,
+ (timestamp_array1 || string_array2) tst_array,
+ (string_array1 || int_array2) sti_array
+FROM various_arrays;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql
new file mode 100644
index 0000000000000..92c7e26e3add2
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql
@@ -0,0 +1,56 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
+ (101, 1, 1, 1),
+ (201, 2, 1, 1),
+ (301, 3, 1, 1),
+ (401, 4, 1, 11),
+ (501, 5, 1, null),
+ (601, 6, null, 1),
+ (701, 6, null, null),
+ (102, 1, 2, 2),
+ (202, 2, 1, 2),
+ (302, 3, 2, 1),
+ (402, 4, 2, 12),
+ (502, 5, 2, null),
+ (602, 6, null, 2),
+ (702, 6, null, null),
+ (103, 1, 3, 3),
+ (203, 2, 1, 3),
+ (303, 3, 3, 1),
+ (403, 4, 3, 13),
+ (503, 5, 3, null),
+ (603, 6, null, 3),
+ (703, 6, null, null),
+ (104, 1, 4, 4),
+ (204, 2, 1, 4),
+ (304, 3, 4, 1),
+ (404, 4, 4, 14),
+ (504, 5, 4, null),
+ (604, 6, null, 4),
+ (704, 6, null, null),
+ (800, 7, 1, 1)
+as t1(id, px, y, x);
+
+select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x),
+ regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x),
+ regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x)
+from t1 group by px order by px;
+
+
+select id, regr_count(y,x) over (partition by px) from t1 order by id;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql
index e57d69eaad033..6da1b9b49b226 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/union.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql
@@ -35,6 +35,17 @@ FROM (SELECT col AS col
SELECT col
FROM p3) T1) T2;
+-- SPARK-24012 Union of map and other compatible columns.
+SELECT map(1, 2), 'str'
+UNION ALL
+SELECT map(1, 2, 3, NULL), 1;
+
+-- SPARK-24012 Union of array and other compatible columns.
+SELECT array(1, 2), 'str'
+UNION ALL
+SELECT array(1, 2, 3, NULL), 1;
+
+
-- Clean-up
DROP VIEW IF EXISTS t1;
DROP VIEW IF EXISTS t2;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql
index c4bea34ec4cf3..cda4db4b449fe 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/window.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql
@@ -76,7 +76,15 @@ ntile(2) OVER w AS ntile,
row_number() OVER w AS row_number,
var_pop(val) OVER w AS var_pop,
var_samp(val) OVER w AS var_samp,
-approx_count_distinct(val) OVER w AS approx_count_distinct
+approx_count_distinct(val) OVER w AS approx_count_distinct,
+covar_pop(val, val_long) OVER w AS covar_pop,
+corr(val, val_long) OVER w AS corr,
+stddev_samp(val) OVER w AS stddev_samp,
+stddev_pop(val) OVER w AS stddev_pop,
+collect_list(val) OVER w AS collect_list,
+collect_set(val) OVER w AS collect_set,
+skewness(val_double) OVER w AS skewness,
+kurtosis(val_double) OVER w AS kurtosis
FROM testData
WINDOW w AS (PARTITION BY cate ORDER BY val)
ORDER BY cate, val;
diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out
index ba8bc936f0c79..ff1ecbcc44c23 100644
--- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 32
+-- Number of queries: 33
-- !query 0
@@ -154,7 +154,7 @@ ALTER TABLE test_change CHANGE invalid_col invalid_col INT
struct<>
-- !query 15 output
org.apache.spark.sql.AnalysisException
-Invalid column reference 'invalid_col', table schema is 'StructType(StructField(a,IntegerType,true), StructField(b,StringType,true), StructField(c,IntegerType,true))';
+Can't find column `invalid_col` given table data columns [`a`, `b`, `c`];
-- !query 16
@@ -291,16 +291,25 @@ ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT
-- !query 30
-DROP TABLE test_change
+ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C'
-- !query 30 schema
struct<>
-- !query 30 output
-
+org.apache.spark.sql.AnalysisException
+Can't find column `c` given table data columns [`a`, `b`];
-- !query 31
-DROP TABLE partition_table
+DROP TABLE test_change
-- !query 31 schema
struct<>
-- !query 31 output
+
+
+-- !query 32
+DROP TABLE partition_table
+-- !query 32 schema
+struct<>
+-- !query 32 output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
index bbb6851e69c7e..9eede305dbdcc 100644
--- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 9
+-- Number of queries: 26
-- !query 0
@@ -81,3 +81,139 @@ struct
-- !query 8 output
1 2
2 3
+
+
+-- !query 9
+select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15')
+-- !query 9 schema
+struct
+-- !query 9 output
+5 3 5 NULL 4
+
+
+-- !query 10
+select from_utc_timestamp('2015-07-24 00:00:00', 'PST')
+-- !query 10 schema
+struct
+-- !query 10 output
+2015-07-23 17:00:00
+
+
+-- !query 11
+select from_utc_timestamp('2015-01-24 00:00:00', 'PST')
+-- !query 11 schema
+struct
+-- !query 11 output
+2015-01-23 16:00:00
+
+
+-- !query 12
+select from_utc_timestamp(null, 'PST')
+-- !query 12 schema
+struct
+-- !query 12 output
+NULL
+
+
+-- !query 13
+select from_utc_timestamp('2015-07-24 00:00:00', null)
+-- !query 13 schema
+struct
+-- !query 13 output
+NULL
+
+
+-- !query 14
+select from_utc_timestamp(null, null)
+-- !query 14 schema
+struct
+-- !query 14 output
+NULL
+
+
+-- !query 15
+select from_utc_timestamp(cast(0 as timestamp), 'PST')
+-- !query 15 schema
+struct
+-- !query 15 output
+1969-12-31 08:00:00
+
+
+-- !query 16
+select from_utc_timestamp(cast('2015-01-24' as date), 'PST')
+-- !query 16 schema
+struct
+-- !query 16 output
+2015-01-23 16:00:00
+
+
+-- !query 17
+select to_utc_timestamp('2015-07-24 00:00:00', 'PST')
+-- !query 17 schema
+struct
+-- !query 17 output
+2015-07-24 07:00:00
+
+
+-- !query 18
+select to_utc_timestamp('2015-01-24 00:00:00', 'PST')
+-- !query 18 schema
+struct
+-- !query 18 output
+2015-01-24 08:00:00
+
+
+-- !query 19
+select to_utc_timestamp(null, 'PST')
+-- !query 19 schema
+struct
+-- !query 19 output
+NULL
+
+
+-- !query 20
+select to_utc_timestamp('2015-07-24 00:00:00', null)
+-- !query 20 schema
+struct
+-- !query 20 output
+NULL
+
+
+-- !query 21
+select to_utc_timestamp(null, null)
+-- !query 21 schema
+struct
+-- !query 21 output
+NULL
+
+
+-- !query 22
+select to_utc_timestamp(cast(0 as timestamp), 'PST')
+-- !query 22 schema
+struct
+-- !query 22 output
+1970-01-01 00:00:00
+
+
+-- !query 23
+select to_utc_timestamp(cast('2015-01-24' as date), 'PST')
+-- !query 23 schema
+struct
+-- !query 23 output
+2015-01-24 08:00:00
+
+
+-- !query 24
+select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST')
+-- !query 24 schema
+struct
+-- !query 24 output
+NULL
+
+
+-- !query 25
+select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST')
+-- !query 25 schema
+struct
+-- !query 25 output
+NULL
diff --git a/sql/core/src/test/resources/sql-tests/results/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/decimalArithmeticOperations.sql.out
index 75d190a6245bb..217233bfad378 100644
--- a/sql/core/src/test/resources/sql-tests/results/decimalArithmeticOperations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/decimalArithmeticOperations.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 50
+-- Number of queries: 54
-- !query 0
@@ -114,313 +114,345 @@ struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.00000000000000000
-- !query 13
-select (5e36 + 0.1) + 5e36
+select 2.35E10 * 1.0
-- !query 13 schema
-struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)>
+struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)>
-- !query 13 output
-NULL
+23500000000
-- !query 14
-select (-4e36 - 0.1) - 7e36
+select (5e36 + 0.1) + 5e36
-- !query 14 schema
-struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)>
+struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)>
-- !query 14 output
NULL
-- !query 15
-select 12345678901234567890.0 * 12345678901234567890.0
+select (-4e36 - 0.1) - 7e36
-- !query 15 schema
-struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)>
+struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)>
-- !query 15 output
NULL
-- !query 16
-select 1e35 / 0.1
+select 12345678901234567890.0 * 12345678901234567890.0
-- !query 16 schema
-struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)>
+struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)>
-- !query 16 output
NULL
-- !query 17
-select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345
+select 1e35 / 0.1
-- !query 17 schema
-struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)>
+struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)>
-- !query 17 output
-10012345678912345678912345678911.246907
+NULL
-- !query 18
-select 123456789123456789.1234567890 * 1.123456789123456789
+select 1.2345678901234567890E30 * 1.2345678901234567890E25
-- !query 18 schema
-struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)>
+struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)>
-- !query 18 output
-138698367904130467.654320988515622621
+NULL
-- !query 19
-select 12345678912345.123456789123 / 0.000000012345678
+select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345
-- !query 19 schema
-struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)>
+struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)>
-- !query 19 output
-1000000073899961059796.725866332
+10012345678912345678912345678911.246907
-- !query 20
-set spark.sql.decimalOperations.allowPrecisionLoss=false
+select 123456789123456789.1234567890 * 1.123456789123456789
-- !query 20 schema
-struct
+struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)>
-- !query 20 output
-spark.sql.decimalOperations.allowPrecisionLoss false
+138698367904130467.654320988515622621
-- !query 21
-select id, a+b, a-b, a*b, a/b from decimals_test order by id
+select 12345678912345.123456789123 / 0.000000012345678
-- !query 21 schema
-struct
+struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)>
-- !query 21 output
-1 1099 -899 NULL 0.1001001001001001
-2 24690.246 0 NULL 1
-3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123
-4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436
+1000000073899961059796.725866332
-- !query 22
-select id, a*10, b/10 from decimals_test order by id
+set spark.sql.decimalOperations.allowPrecisionLoss=false
-- !query 22 schema
-struct
+struct
-- !query 22 output
-1 1000 99.9
-2 123451.23 1234.5123
-3 1.234567891011 123.41
-4 1234567891234567890 0.1123456789123456789
+spark.sql.decimalOperations.allowPrecisionLoss false
-- !query 23
-select 10.3 * 3.0
+select id, a+b, a-b, a*b, a/b from decimals_test order by id
-- !query 23 schema
-struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)>
+struct
-- !query 23 output
-30.9
+1 1099 -899 NULL 0.1001001001001001
+2 24690.246 0 NULL 1
+3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123
+4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436
-- !query 24
-select 10.3000 * 3.0
+select id, a*10, b/10 from decimals_test order by id
-- !query 24 schema
-struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)>
+struct
-- !query 24 output
-30.9
+1 1000 99.9
+2 123451.23 1234.5123
+3 1.234567891011 123.41
+4 1234567891234567890 0.1123456789123456789
-- !query 25
-select 10.30000 * 30.0
+select 10.3 * 3.0
-- !query 25 schema
-struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)>
+struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)>
-- !query 25 output
-309
+30.9
-- !query 26
-select 10.300000000000000000 * 3.000000000000000000
+select 10.3000 * 3.0
-- !query 26 schema
-struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)>
+struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)>
-- !query 26 output
30.9
-- !query 27
-select 10.300000000000000000 * 3.0000000000000000000
+select 10.30000 * 30.0
-- !query 27 schema
-struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)>
+struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)>
-- !query 27 output
-NULL
+309
-- !query 28
-select (5e36 + 0.1) + 5e36
+select 10.300000000000000000 * 3.000000000000000000
-- !query 28 schema
-struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)>
+struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)>
-- !query 28 output
-NULL
+30.9
-- !query 29
-select (-4e36 - 0.1) - 7e36
+select 10.300000000000000000 * 3.0000000000000000000
-- !query 29 schema
-struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)>
+struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)>
-- !query 29 output
NULL
-- !query 30
-select 12345678901234567890.0 * 12345678901234567890.0
+select 2.35E10 * 1.0
-- !query 30 schema
-struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)>
+struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)>
-- !query 30 output
-NULL
+23500000000
-- !query 31
-select 1e35 / 0.1
+select (5e36 + 0.1) + 5e36
-- !query 31 schema
-struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)>
+struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)>
-- !query 31 output
NULL
-- !query 32
-select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345
+select (-4e36 - 0.1) - 7e36
-- !query 32 schema
-struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)>
+struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)>
-- !query 32 output
NULL
-- !query 33
-select 123456789123456789.1234567890 * 1.123456789123456789
+select 12345678901234567890.0 * 12345678901234567890.0
-- !query 33 schema
-struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)>
+struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)>
-- !query 33 output
NULL
-- !query 34
-select 12345678912345.123456789123 / 0.000000012345678
+select 1e35 / 0.1
-- !query 34 schema
-struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)>
+struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)>
-- !query 34 output
NULL
-- !query 35
-set spark.sql.decimalOperations.nullOnOverflow=false
+select 1.2345678901234567890E30 * 1.2345678901234567890E25
-- !query 35 schema
-struct
+struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)>
-- !query 35 output
-spark.sql.decimalOperations.nullOnOverflow false
+NULL
-- !query 36
-select id, a*10, b/10 from decimals_test order by id
+select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345
-- !query 36 schema
-struct
+struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)>
-- !query 36 output
+NULL
+
+
+-- !query 37
+select 123456789123456789.1234567890 * 1.123456789123456789
+-- !query 37 schema
+struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)>
+-- !query 37 output
+NULL
+
+
+-- !query 38
+select 12345678912345.123456789123 / 0.000000012345678
+-- !query 38 schema
+struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)>
+-- !query 38 output
+NULL
+
+
+-- !query 39
+set spark.sql.decimalOperations.nullOnOverflow=false
+-- !query 39 schema
+struct
+-- !query 39 output
+spark.sql.decimalOperations.nullOnOverflow false
+
+
+-- !query 40
+select id, a*10, b/10 from decimals_test order by id
+-- !query 40 schema
+struct
+-- !query 40 output
1 1000 99.9
2 123451.23 1234.5123
3 1.234567891011 123.41
4 1234567891234567890 0.1123456789123456789
--- !query 37
+-- !query 41
select 10.3 * 3.0
--- !query 37 schema
+-- !query 41 schema
struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)>
--- !query 37 output
+-- !query 41 output
30.9
--- !query 38
+-- !query 42
select 10.3000 * 3.0
--- !query 38 schema
+-- !query 42 schema
struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)>
--- !query 38 output
+-- !query 42 output
30.9
--- !query 39
+-- !query 43
select 10.30000 * 30.0
--- !query 39 schema
+-- !query 43 schema
struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)>
--- !query 39 output
+-- !query 43 output
309
--- !query 40
+-- !query 44
select 10.300000000000000000 * 3.000000000000000000
--- !query 40 schema
+-- !query 44 schema
struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)>
--- !query 40 output
+-- !query 44 output
30.9
--- !query 41
+-- !query 45
select 10.300000000000000000 * 3.0000000000000000000
--- !query 41 schema
+-- !query 45 schema
struct<>
--- !query 41 output
+-- !query 45 output
java.lang.ArithmeticException
Decimal(expanded,30.900000000000000000000000000000000000,38,36}) cannot be represented as Decimal(38, 37).
--- !query 42
+-- !query 46
select (5e36 + 0.1) + 5e36
--- !query 42 schema
+-- !query 46 schema
struct<>
--- !query 42 output
+-- !query 46 output
java.lang.ArithmeticException
Decimal(expanded,10000000000000000000000000000000000000.1,39,1}) cannot be represented as Decimal(38, 1).
--- !query 43
+-- !query 47
select (-4e36 - 0.1) - 7e36
--- !query 43 schema
+-- !query 47 schema
struct<>
--- !query 43 output
+-- !query 47 output
java.lang.ArithmeticException
Decimal(expanded,-11000000000000000000000000000000000000.1,39,1}) cannot be represented as Decimal(38, 1).
--- !query 44
+-- !query 48
select 12345678901234567890.0 * 12345678901234567890.0
--- !query 44 schema
+-- !query 48 schema
struct<>
--- !query 44 output
+-- !query 48 output
java.lang.ArithmeticException
Decimal(expanded,1.5241578753238836750190519987501905210E+38,38,-1}) cannot be represented as Decimal(38, 2).
--- !query 45
+-- !query 49
select 1e35 / 0.1
--- !query 45 schema
+-- !query 49 schema
struct<>
--- !query 45 output
+-- !query 49 output
java.lang.ArithmeticException
Decimal(expanded,1000000000000000000000000000000000000,37,0}) cannot be represented as Decimal(38, 3).
--- !query 46
+-- !query 50
select 123456789123456789.1234567890 * 1.123456789123456789
--- !query 46 schema
+-- !query 50 schema
struct<>
--- !query 46 output
+-- !query 50 output
java.lang.ArithmeticException
Decimal(expanded,138698367904130467.65432098851562262075,38,20}) cannot be represented as Decimal(38, 28).
--- !query 47
+-- !query 51
select 123456789123456789.1234567890 * 1.123456789123456789
--- !query 47 schema
+-- !query 51 schema
struct<>
--- !query 47 output
+-- !query 51 output
java.lang.ArithmeticException
Decimal(expanded,138698367904130467.65432098851562262075,38,20}) cannot be represented as Decimal(38, 28).
--- !query 48
+-- !query 52
select 12345678912345.123456789123 / 0.000000012345678
--- !query 48 schema
+-- !query 52 schema
struct<>
--- !query 48 output
+-- !query 52 output
java.lang.ArithmeticException
Decimal(expanded,1000000073899961059796.7258663315210392,38,16}) cannot be represented as Decimal(38, 18).
--- !query 49
+-- !query 53
drop table decimals_test
--- !query 49 schema
+-- !query 53 schema
struct<>
--- !query 49 output
+-- !query 53 output
diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
index 51dac111029e8..58ed201e2a60f 100644
--- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
@@ -89,7 +89,7 @@ Database default
Table t
Partition Values [ds=2017-08-01, hr=10]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
-Partition Statistics 1067 bytes, 3 rows
+Partition Statistics 1121 bytes, 3 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -122,7 +122,7 @@ Database default
Table t
Partition Values [ds=2017-08-01, hr=10]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
-Partition Statistics 1067 bytes, 3 rows
+Partition Statistics 1121 bytes, 3 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -147,7 +147,7 @@ Database default
Table t
Partition Values [ds=2017-08-01, hr=11]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11
-Partition Statistics 1080 bytes, 4 rows
+Partition Statistics 1098 bytes, 4 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -180,7 +180,7 @@ Database default
Table t
Partition Values [ds=2017-08-01, hr=10]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
-Partition Statistics 1067 bytes, 3 rows
+Partition Statistics 1121 bytes, 3 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -205,7 +205,7 @@ Database default
Table t
Partition Values [ds=2017-08-01, hr=11]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11
-Partition Statistics 1080 bytes, 4 rows
+Partition Statistics 1098 bytes, 4 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -230,7 +230,7 @@ Database default
Table t
Partition Values [ds=2017-09-01, hr=5]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5
-Partition Statistics 1054 bytes, 2 rows
+Partition Statistics 1144 bytes, 2 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out
new file mode 100644
index 0000000000000..160e4c7d78455
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out
@@ -0,0 +1,96 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 11
+
+
+-- !query 0
+CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select extract(year from c) from t
+-- !query 1 schema
+struct
+-- !query 1 output
+2011
+
+
+-- !query 2
+select extract(quarter from c) from t
+-- !query 2 schema
+struct
+-- !query 2 output
+2
+
+
+-- !query 3
+select extract(month from c) from t
+-- !query 3 schema
+struct
+-- !query 3 output
+5
+
+
+-- !query 4
+select extract(week from c) from t
+-- !query 4 schema
+struct
+-- !query 4 output
+18
+
+
+-- !query 5
+select extract(day from c) from t
+-- !query 5 schema
+struct
+-- !query 5 output
+6
+
+
+-- !query 6
+select extract(dayofweek from c) from t
+-- !query 6 schema
+struct
+-- !query 6 output
+6
+
+
+-- !query 7
+select extract(hour from c) from t
+-- !query 7 schema
+struct
+-- !query 7 output
+7
+
+
+-- !query 8
+select extract(minute from c) from t
+-- !query 8 schema
+struct
+-- !query 8 output
+8
+
+
+-- !query 9
+select extract(second from c) from t
+-- !query 9 schema
+struct
+-- !query 9 output
+9
+
+
+-- !query 10
+select extract(not_supported from c) from t
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+Literals of type 'NOT_SUPPORTED' are currently not supported.(line 1, pos 7)
+
+== SQL ==
+select extract(not_supported from c) from t
+-------^^^
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index c1abc6dff754b..581aa1754ce14 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 26
+-- Number of queries: 27
-- !query 0
@@ -241,3 +241,12 @@ where b.z != b.z
struct<1:int>
-- !query 25 output
+
+
+-- !query 26
+SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
+ FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y)
+-- !query 26 schema
+struct
+-- !query 26 output
+1.0 1.0 3
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
index 581dddc89d0bb..2b3288dc5a137 100644
--- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 26
+-- Number of queries: 28
-- !query 0
@@ -129,7 +129,7 @@ select to_json()
struct<>
-- !query 12 output
org.apache.spark.sql.AnalysisException
-Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7
+Invalid number of arguments for function to_json. Expected: one of 1 and 2; Found: 0; line 1 pos 7
-- !query 13
@@ -225,7 +225,7 @@ select from_json()
struct<>
-- !query 21 output
org.apache.spark.sql.AnalysisException
-Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7
+Invalid number of arguments for function from_json. Expected: one of 2 and 3; Found: 0; line 1 pos 7
-- !query 22
@@ -258,3 +258,19 @@ DROP VIEW IF EXISTS jsonTable
struct<>
-- !query 25 output
+
+
+-- !query 26
+select from_json('{"a":1, "b":2}', 'map')
+-- !query 26 schema
+struct>
+-- !query 26 output
+{"a":1,"b":2}
+
+
+-- !query 27
+select from_json('{"a":1, "b":"2"}', 'struct')
+-- !query 27 schema
+struct>
+-- !query 27 output
+{"a":1,"b":"2"}
diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out
index 95d4413148f64..b8c91dc8b59a4 100644
--- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 43
+-- Number of queries: 44
-- !query 0
@@ -323,19 +323,17 @@ select timestamp '2016-33-11 20:54:00.000'
-- !query 34
select interval 13.123456789 seconds, interval -13.123456789 second
-- !query 34 schema
-struct<>
+struct
-- !query 34 output
-scala.MatchError
-(interval 13 seconds 123 milliseconds 456 microseconds,CalendarIntervalType) (of class scala.Tuple2)
+interval 13 seconds 123 milliseconds 456 microseconds interval -12 seconds -876 milliseconds -544 microseconds
-- !query 35
select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond
-- !query 35 schema
-struct<>
+struct
-- !query 35 output
-scala.MatchError
-(interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds,CalendarIntervalType) (of class scala.Tuple2)
+interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds 9
-- !query 36
@@ -416,3 +414,11 @@ SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8
struct<3.14:decimal(3,2),-3.14:decimal(3,2),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10),-3.14E+8:decimal(3,-6),-3.14E-8:decimal(10,10),3.14E+8:decimal(3,-6),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10)>
-- !query 42 output
3.14 -3.14 314000000 0.0000000314 -314000000 -0.0000000314 314000000 314000000 0.0000000314
+
+
+-- !query 43
+select map(1, interval 1 day, 2, interval 3 week)
+-- !query 43 schema
+struct