diff --git a/R/README.md b/R/README.md index 31174c7352..da9f042b4f 100644 --- a/R/README.md +++ b/R/README.md @@ -17,10 +17,14 @@ export R_HOME=/home/username/R #### Build Spark -Build Spark with [Maven](https://spark.apache.org/docs/latest/building-spark.html#buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run +Build Spark with [Maven](https://spark.apache.org/docs/latest/building-spark.html#buildmvn) or [SBT](https://spark.apache.org/docs/latest/building-spark.html#building-with-sbt), and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run ```bash +# Maven ./build/mvn -DskipTests -Psparkr package + +# SBT +./build/sbt -Psparkr package ``` #### Running sparkR diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index 5b49a01395..20339c947d 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -38,14 +38,14 @@ test_that("spark.svmLinear", { expect_true(class(summary$coefficients[, 1]) == "numeric") coefs <- summary$coefficients[, "Estimate"] - expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085) + expected_coefs <- c(-6.8823988, -0.6154984, -1.5135447, 1.9694126, 3.3736856) expect_true(all(abs(coefs - expected_coefs) < 0.1)) # Test prediction with string label prediction <- predict(model, training) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") - expected <- c("versicolor", "versicolor", "versicolor", "virginica", "virginica", - "virginica", "virginica", "virginica", "virginica", "virginica") + expected <- c("versicolor", "versicolor", "versicolor", "versicolor", "versicolor", + "versicolor", "versicolor", "versicolor", "versicolor", "versicolor") expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) # Test model save and load diff --git a/build/mvn b/build/mvn index 719d7573f4..4e53a16bcd 100755 --- a/build/mvn +++ b/build/mvn @@ -31,7 +31,7 @@ _COMPILE_JVM_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=1g" ## Arg2 - Tarball Name ## Arg3 - Checkable Binary install_app() { - local remote_tarball="$1/$2" + local remote_tarball="$1" local local_tarball="${_DIR}/$2" local binary="${_DIR}/$3" @@ -71,19 +71,20 @@ install_mvn() { local MVN_DETECTED_VERSION="$(mvn --version | head -n1 | awk '{print $3}')" fi if [ $(version $MVN_DETECTED_VERSION) -lt $(version $MVN_VERSION) ]; then - local APACHE_MIRROR=${APACHE_MIRROR:-'https://www.apache.org/dyn/closer.lua?action=download&filename='} - + local FILE_PATH="maven/maven-3/${MVN_VERSION}/binaries/apache-maven-${MVN_VERSION}-bin.tar.gz" + local APACHE_MIRROR=${APACHE_MIRROR:-'https://www.apache.org/dyn/closer.lua'} + local MIRROR_URL="${APACHE_MIRROR}/${FILE_PATH}?action=download" + if [ $(command -v curl) ]; then - local TEST_MIRROR_URL="${APACHE_MIRROR}/maven/maven-3/${MVN_VERSION}/binaries/apache-maven-${MVN_VERSION}-bin.tar.gz" - if ! curl -L --output /dev/null --silent --head --fail "$TEST_MIRROR_URL" ; then + if ! 
curl -L --output /dev/null --silent --head --fail "${MIRROR_URL}" ; then # Fall back to archive.apache.org for older Maven echo "Falling back to archive.apache.org to download Maven" - APACHE_MIRROR="https://archive.apache.org/dist" + MIRROR_URL="https://archive.apache.org/dist/${FILE_PATH}" fi fi install_app \ - "${APACHE_MIRROR}/maven/maven-3/${MVN_VERSION}/binaries" \ + "${MIRROR_URL}" \ "apache-maven-${MVN_VERSION}-bin.tar.gz" \ "apache-maven-${MVN_VERSION}/bin/mvn" @@ -102,7 +103,7 @@ install_scala() { local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ - "${TYPESAFE_MIRROR}/scala/${scala_version}" \ + "${TYPESAFE_MIRROR}/scala/${scala_version}/scala-${scala_version}.tgz" \ "scala-${scala_version}.tgz" \ "scala-${scala_version}/bin/scala" diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index dc7b9ea2d2..5db8c5c295 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -22,10 +22,12 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n -# Set the default spark-shell log level to WARN. When running the spark-shell, the -# log level for this class is used to overwrite the root logger's log level, so that -# the user can have different defaults for the shell and regular Spark apps. +# Set the default spark-shell/spark-sql log level to WARN. When running the +# spark-shell/spark-sql, the log level for these classes is used to overwrite +# the root logger's log level, so that the user can have different defaults +# for the shell and regular Spark apps. log4j.logger.org.apache.spark.repl.Main=WARN +log4j.logger.org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver=WARN # Settings to quiet third party logs that are too verbose log4j.logger.org.sparkproject.jetty=WARN diff --git a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js index 91bf274aa4..8b32fe7d3e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/stagepage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/stagepage.js @@ -43,6 +43,23 @@ $.extend( $.fn.dataTable.ext.type.order, { a = ConvertDurationString( a ); b = ConvertDurationString( b ); return ((a < b) ? 1 : ((a > b) ? -1 : 0)); + }, + + "size-pre": function (data) { + var floatValue = parseFloat(data) + return isNaN(floatValue) ? 0 : floatValue; + }, + + "size-asc": function (a, b) { + a = parseFloat(a); + b = parseFloat(b); + return ((a < b) ? -1 : ((a > b) ? 1 : 0)); + }, + + "size-desc": function (a, b) { + a = parseFloat(a); + b = parseFloat(b); + return ((a < b) ? 1 : ((a > b) ? -1 : 0)); } } ); @@ -562,10 +579,27 @@ $(document).ready(function () { } ], "columnDefs": [ - { "visible": false, "targets": 15 }, - { "visible": false, "targets": 16 }, - { "visible": false, "targets": 17 }, - { "visible": false, "targets": 18 } + // SPARK-35087 [type:size] means String with structures like : 'size / records', + // they should be sorted as numerical-order instead of lexicographical-order by default. 
+ // The targets: $id represents column id which comes from stagespage-template.html + // #summary-executor-table.If the relative position of the columns in the table + // #summary-executor-table has changed,please be careful to adjust the column index here + // Input Size / Records + {"type": "size", "targets": 9}, + // Output Size / Records + {"type": "size", "targets": 10}, + // Shuffle Read Size / Records + {"type": "size", "targets": 11}, + // Shuffle Write Size / Records + {"type": "size", "targets": 12}, + // Peak JVM Memory OnHeap / OffHeap + {"visible": false, "targets": 15}, + // Peak Execution Memory OnHeap / OffHeap + {"visible": false, "targets": 16}, + // Peak Storage Memory OnHeap / OffHeap + {"visible": false, "targets": 17}, + // Peak Pool Memory Direct / Mapped + {"visible": false, "targets": 18} ], "deferRender": true, "order": [[0, "asc"]], @@ -746,7 +780,7 @@ $(document).ready(function () { "paging": true, "info": true, "processing": true, - "lengthMenu": [[20, 40, 60, 100, totalTasksToShow], [20, 40, 60, 100, "All"]], + "lengthMenu": [[20, 40, 60, 100, -1], [20, 40, 60, 100, "All"]], "orderMulti": false, "bAutoWidth": false, "ajax": { @@ -762,6 +796,9 @@ $(document).ready(function () { data.numTasks = totalTasksToShow; data.columnIndexToSort = columnIndexToSort; data.columnNameToSort = columnNameToSort; + if (data.length === -1) { + data.length = totalTasksToShow; + } }, "dataSrc": function (jsons) { var jsonStr = JSON.stringify(jsons); diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 822a0a5d5e..779559b116 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -290,11 +290,11 @@ private[spark] class ExecutorAllocationManager( * under the current load to satisfy all running and pending tasks, rounded up. */ private[spark] def maxNumExecutorsNeededPerResourceProfile(rpId: Int): Int = { - val pending = listener.totalPendingTasksPerResourceProfile(rpId) + val pendingTask = listener.pendingTasksPerResourceProfile(rpId) val pendingSpeculative = listener.pendingSpeculativeTasksPerResourceProfile(rpId) val unschedulableTaskSets = listener.pendingUnschedulableTaskSetsPerResourceProfile(rpId) val running = listener.totalRunningTasksPerResourceProfile(rpId) - val numRunningOrPendingTasks = pending + running + val numRunningOrPendingTasks = pendingTask + pendingSpeculative + running val rp = resourceProfileManager.resourceProfileFromId(rpId) val tasksPerExecutor = rp.maxTasksPerExecutor(conf) logDebug(s"max needed for rpId: $rpId numpending: $numRunningOrPendingTasks," + @@ -916,18 +916,6 @@ private[spark] class ExecutorAllocationManager( hasPendingSpeculativeTasks || hasPendingRegularTasks } - def totalPendingTasksPerResourceProfile(rp: Int): Int = { - pendingTasksPerResourceProfile(rp) + pendingSpeculativeTasksPerResourceProfile(rp) - } - - /** - * The number of tasks currently running across all stages. 
- * Include running-but-zombie stage attempts - */ - def totalRunningTasks(): Int = { - stageAttemptToNumRunningTask.values.sum - } - def totalRunningTasksPerResourceProfile(rp: Int): Int = { val attempts = resourceProfileIdToStageAttempt.getOrElse(rp, Set.empty).toSeq // attempts is a Set, change to Seq so we keep all values diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index ce71c2c7bc..b749d7e862 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -29,13 +29,14 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream} +import org.roaringbitmap.RoaringBitmap import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus} import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ @@ -49,7 +50,9 @@ import org.apache.spark.util._ * * All public methods of this class are thread-safe. */ -private class ShuffleStatus(numPartitions: Int) extends Logging { +private class ShuffleStatus( + numPartitions: Int, + numReducers: Int = -1) extends Logging { private val (readLock, writeLock) = { val lock = new ReentrantReadWriteLock() @@ -86,6 +89,19 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { // Exposed for testing val mapStatuses = new Array[MapStatus](numPartitions) + /** + * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the + * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for + * a shuffle partition, or null if not available. When push-based shuffle is enabled, this array + * provides a reducer oriented view of the shuffle status specifically for the results of + * merging shuffle partition blocks into per-partition merged shuffle files. + */ + val mergeStatuses = if (numReducers > 0) { + new Array[MergeStatus](numReducers) + } else { + Array.empty[MergeStatus] + } + /** * The cached result of serializing the map statuses array. This cache is lazily populated when * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. @@ -102,12 +118,24 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { */ private[spark] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + /** + * Similar to cachedSerializedMapStatus and cachedSerializedBroadcast, but for MergeStatus. + */ + private[this] var cachedSerializedMergeStatus: Array[Byte] = _ + + private[this] var cachedSerializedBroadcastMergeStatus: Broadcast[Array[Byte]] = _ + /** * Counter tracking the number of partitions that have output. This is a performance optimization * to avoid having to count the number of non-null entries in the `mapStatuses` array and should * be equivalent to`mapStatuses.count(_ ne null)`. 
*/ - private[this] var _numAvailableOutputs: Int = 0 + private[this] var _numAvailableMapOutputs: Int = 0 + + /** + * Counter tracking the number of MergeStatus results received so far from the shuffle services. + */ + private[this] var _numAvailableMergeResults: Int = 0 /** * Register a map output. If there is already a registered location for the map output then it @@ -115,7 +143,7 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { */ def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock { if (mapStatuses(mapIndex) == null) { - _numAvailableOutputs += 1 + _numAvailableMapOutputs += 1 invalidateSerializedMapOutputStatusCache() } mapStatuses(mapIndex) = status @@ -149,12 +177,36 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock { logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}") if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) { - _numAvailableOutputs -= 1 + _numAvailableMapOutputs -= 1 mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } } + /** + * Register a merge result. + */ + def addMergeResult(reduceId: Int, status: MergeStatus): Unit = withWriteLock { + if (mergeStatuses(reduceId) != status) { + _numAvailableMergeResults += 1 + invalidateSerializedMergeOutputStatusCache() + } + mergeStatuses(reduceId) = status + } + + // TODO support updateMergeResult for similar use cases as updateMapOutput + + /** + * Remove the merge result which was served by the specified block manager. + */ + def removeMergeResult(reduceId: Int, bmAddress: BlockManagerId): Unit = withWriteLock { + if (mergeStatuses(reduceId) != null && mergeStatuses(reduceId).location == bmAddress) { + _numAvailableMergeResults -= 1 + mergeStatuses(reduceId) = null + invalidateSerializedMergeOutputStatusCache() + } + } + /** * Removes all shuffle outputs associated with this host. Note that this will also remove * outputs which are served by an external shuffle server (if one exists). @@ -181,18 +233,33 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock { for (mapIndex <- mapStatuses.indices) { if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) { - _numAvailableOutputs -= 1 + _numAvailableMapOutputs -= 1 mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } } + for (reduceId <- mergeStatuses.indices) { + if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) { + _numAvailableMergeResults -= 1 + mergeStatuses(reduceId) = null + invalidateSerializedMergeOutputStatusCache() + } + } + } + + /** + * Number of partitions that have shuffle map outputs. + */ + def numAvailableMapOutputs: Int = withReadLock { + _numAvailableMapOutputs } /** - * Number of partitions that have shuffle outputs. + * Number of shuffle partitions that have already been merge finalized when push-based + * is enabled. 
*/ - def numAvailableOutputs: Int = withReadLock { - _numAvailableOutputs + def numAvailableMergeResults: Int = withReadLock { + _numAvailableMergeResults } /** @@ -200,19 +267,19 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { */ def findMissingPartitions(): Seq[Int] = withReadLock { val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null) - assert(missing.size == numPartitions - _numAvailableOutputs, - s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") + assert(missing.size == numPartitions - _numAvailableMapOutputs, + s"${missing.size} missing, expected ${numPartitions - _numAvailableMapOutputs}") missing } /** * Serializes the mapStatuses array into an efficient compressed format. See the comments on - * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format. + * `MapOutputTracker.serializeOutputStatuses()` for more details on the serialization format. * * This method is designed to be called multiple times and implements caching in order to speed * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to - * serialize the map statuses then serialization will only be performed in a single thread and all - * other threads will block until the cache is populated. + * serialize the map statuses then serialization will only be performed in a single thread and + * all other threads will block until the cache is populated. */ def serializedMapStatus( broadcastManager: BroadcastManager, @@ -220,7 +287,6 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { minBroadcastSize: Int, conf: SparkConf): Array[Byte] = { var result: Array[Byte] = null - withReadLock { if (cachedSerializedMapStatus != null) { result = cachedSerializedMapStatus @@ -229,7 +295,7 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { if (result == null) withWriteLock { if (cachedSerializedMapStatus == null) { - val serResult = MapOutputTracker.serializeMapStatuses( + val serResult = MapOutputTracker.serializeOutputStatuses[MapStatus]( mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf) cachedSerializedMapStatus = serResult._1 cachedSerializedBroadcast = serResult._2 @@ -241,6 +307,47 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { result } + /** + * Serializes the mapStatuses and mergeStatuses array into an efficient compressed format. + * See the comments on `MapOutputTracker.serializeOutputStatuses()` for more details + * on the serialization format. + * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the statuses array then serialization will only be performed in a single thread and + * all other threads will block until the cache is populated. 
+ */ + def serializedMapAndMergeStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int, + conf: SparkConf): (Array[Byte], Array[Byte]) = { + val mapStatusesBytes: Array[Byte] = + serializedMapStatus(broadcastManager, isLocal, minBroadcastSize, conf) + var mergeStatusesBytes: Array[Byte] = null + + withReadLock { + if (cachedSerializedMergeStatus != null) { + mergeStatusesBytes = cachedSerializedMergeStatus + } + } + + if (mergeStatusesBytes == null) withWriteLock { + if (cachedSerializedMergeStatus == null) { + val serResult = MapOutputTracker.serializeOutputStatuses[MergeStatus]( + mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf) + cachedSerializedMergeStatus = serResult._1 + cachedSerializedBroadcastMergeStatus = serResult._2 + } + + // The following line has to be outside if statement since it's possible that another + // thread initializes cachedSerializedMergeStatus in-between `withReadLock` and + // `withWriteLock`. + mergeStatusesBytes = cachedSerializedMergeStatus + } + (mapStatusesBytes, mergeStatusesBytes) + } + // Used in testing. def hasCachedSerializedBroadcast: Boolean = withReadLock { cachedSerializedBroadcast != null @@ -254,6 +361,10 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { f(mapStatuses) } + def withMergeStatuses[T](f: Array[MergeStatus] => T): T = withReadLock { + f(mergeStatuses) + } + /** * Clears the cached serialized map output statuses. */ @@ -269,14 +380,35 @@ private class ShuffleStatus(numPartitions: Int) extends Logging { } cachedSerializedMapStatus = null } + + /** + * Clears the cached serialized merge result statuses. + */ + def invalidateSerializedMergeOutputStatusCache(): Unit = withWriteLock { + if (cachedSerializedBroadcastMergeStatus != null) { + Utils.tryLogNonFatalError { + // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup + // RPCs to dead executors. 
+ cachedSerializedBroadcastMergeStatus.destroy() + } + cachedSerializedBroadcastMergeStatus = null + } + cachedSerializedMergeStatus = null + } } private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage +private[spark] case class GetMapAndMergeResultStatuses(shuffleId: Int) + extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext) +private[spark] sealed trait MapOutputTrackerMasterMessage +private[spark] case class GetMapOutputMessage(shuffleId: Int, + context: RpcCallContext) extends MapOutputTrackerMasterMessage +private[spark] case class GetMapAndMergeOutputMessage(shuffleId: Int, + context: RpcCallContext) extends MapOutputTrackerMasterMessage /** RpcEndpoint class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterEndpoint( @@ -288,8 +420,13 @@ private[spark] class MapOutputTrackerMasterEndpoint( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort - logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}") - tracker.post(new GetMapOutputMessage(shuffleId, context)) + logInfo(s"Asked to send map output locations for shuffle $shuffleId to $hostPort") + tracker.post(GetMapOutputMessage(shuffleId, context)) + + case GetMapAndMergeResultStatuses(shuffleId: Int) => + val hostPort = context.senderAddress.hostPort + logInfo(s"Asked to send map/merge result locations for shuffle $shuffleId to $hostPort") + tracker.post(GetMapAndMergeOutputMessage(shuffleId, context)) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") @@ -367,6 +504,40 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging startPartition: Int, endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** + * Called from executors upon fetch failure on an entire merged shuffle reduce partition. + * Such failures can happen if the shuffle client fails to fetch the metadata for the given + * merged shuffle partition. This method is to get the server URIs and output sizes for each + * shuffle block that is merged in the specified merged shuffle block so fetch failure on a + * merged shuffle block can fall back to fetching the unmerged blocks. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + + /** + * Called from executors upon fetch failure on a merged shuffle reduce partition chunk. This is + * to get the server URIs and output sizes for each shuffle block that is merged in the specified + * merged shuffle partition chunk so fetch failure on a merged shuffle block chunk can fall back + * to fetching the unmerged blocks. + * + * chunkBitMap tracks the mapIds which are part of the current merged chunk, this way if there is + * a fetch failure on the merged chunk, it can fallback to fetching the corresponding original + * blocks part of this merged chunk. 
+ * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkBitmap: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** * Deletes map output status information for the specified shuffle stage. */ @@ -415,8 +586,11 @@ private[spark] class MapOutputTrackerMaster( private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - // requests for map output statuses - private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] + // requests for MapOutputTrackerMasterMessages + private val mapOutputTrackerMasterMessages = + new LinkedBlockingQueue[MapOutputTrackerMasterMessage] + + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf) // Thread pool used for handling map output status requests. This is a separate thread pool // to ensure we don't block the normal dispatcher threads. @@ -439,31 +613,47 @@ private[spark] class MapOutputTrackerMaster( throw new IllegalArgumentException(msg) } - def post(message: GetMapOutputMessage): Unit = { - mapOutputRequests.offer(message) + def post(message: MapOutputTrackerMasterMessage): Unit = { + mapOutputTrackerMasterMessages.offer(message) } /** Message loop used for dispatching messages. */ private class MessageLoop extends Runnable { + private def handleStatusMessage( + shuffleId: Int, + context: RpcCallContext, + needMergeOutput: Boolean): Unit = { + val hostPort = context.senderAddress.hostPort + val shuffleStatus = shuffleStatuses.get(shuffleId).head + logDebug(s"Handling request to send ${if (needMergeOutput) "map" else "map/merge"}" + + s" output locations for shuffle $shuffleId to $hostPort") + if (needMergeOutput) { + context.reply( + shuffleStatus. + serializedMapAndMergeStatus(broadcastManager, isLocal, minSizeForBroadcast, conf)) + } else { + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, conf)) + } + } + override def run(): Unit = { try { while (true) { try { - val data = mapOutputRequests.take() - if (data == PoisonPill) { + val data = mapOutputTrackerMasterMessages.take() + if (data == PoisonPill) { // Put PoisonPill back so that other MessageLoops can see it. - mapOutputRequests.offer(PoisonPill) + mapOutputTrackerMasterMessages.offer(PoisonPill) return } - val context = data.context - val shuffleId = data.shuffleId - val hostPort = context.senderAddress.hostPort - logDebug("Handling request to send map output locations for shuffle " + shuffleId + - " to " + hostPort) - val shuffleStatus = shuffleStatuses.get(shuffleId).head - context.reply( - shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, - conf)) + + data match { + case GetMapOutputMessage(shuffleId, context) => + handleStatusMessage(shuffleId, context, false) + case GetMapAndMergeOutputMessage(shuffleId, context) => + handleStatusMessage(shuffleId, context, true) + } } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -475,16 +665,22 @@ private[spark] class MapOutputTrackerMaster( } /** A poison endpoint that indicates MessageLoop should exit its message loop. */ - private val PoisonPill = new GetMapOutputMessage(-99, null) + private val PoisonPill = GetMapOutputMessage(-99, null) // Used only in unit tests. 
private[spark] def getNumCachedSerializedBroadcast: Int = { shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) } - def registerShuffle(shuffleId: Int, numMaps: Int): Unit = { - if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { - throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int): Unit = { + if (pushBasedShuffleEnabled) { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps, numReduces)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } + } else { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } } } @@ -524,10 +720,49 @@ private[spark] class MapOutputTrackerMaster( } } + def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus) { + shuffleStatuses(shuffleId).addMergeResult(reduceId, status) + } + + def registerMergeResults(shuffleId: Int, statuses: Seq[(Int, MergeStatus)]): Unit = { + statuses.foreach { + case (reduceId, status) => registerMergeResult(shuffleId, reduceId, status) + } + } + + /** + * Unregisters a merge result corresponding to the reduceId if present. If the optional mapId + * is specified, it will only unregister the merge result if the mapId is part of that merge + * result. + * + * @param shuffleId the shuffleId. + * @param reduceId the reduceId. + * @param bmAddress block manager address. + * @param mapId the optional mapId which should be checked to see it was part of the merge + * result. + */ + def unregisterMergeResult( + shuffleId: Int, + reduceId: Int, + bmAddress: BlockManagerId, + mapId: Option[Int] = None) { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + val mergeStatus = shuffleStatus.mergeStatuses(reduceId) + if (mergeStatus != null && (mapId.isEmpty || mergeStatus.tracker.contains(mapId.get))) { + shuffleStatus.removeMergeResult(reduceId, bmAddress) + incrementEpoch() + } + case None => + throw new SparkException("unregisterMergeResult called for nonexistent shuffle ID") + } + } + /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int): Unit = { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => shuffleStatus.invalidateSerializedMapOutputStatusCache() + shuffleStatus.invalidateSerializedMergeOutputStatusCache() } } @@ -554,7 +789,12 @@ private[spark] class MapOutputTrackerMaster( def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId) def getNumAvailableOutputs(shuffleId: Int): Int = { - shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0) + shuffleStatuses.get(shuffleId).map(_.numAvailableMapOutputs).getOrElse(0) + } + + /** VisibleForTest. Invoked in test only. */ + private[spark] def getNumAvailableMergeResults(shuffleId: Int): Int = { + shuffleStatuses.get(shuffleId).map(_.numAvailableMergeResults).getOrElse(0) } /** @@ -633,7 +873,9 @@ private[spark] class MapOutputTrackerMaster( /** * Return the preferred hosts on which to run the given map output partition in a given shuffle, - * i.e. the nodes that the most outputs for that partition are on. + * i.e. the nodes that the most outputs for that partition are on. If the map output is + * pre-merged, then return the node where the merged block is located if the merge ratio is + * above the threshold. 
* * @param dep shuffle dependency object * @param partitionId map output partition that we want to read @@ -641,15 +883,40 @@ private[spark] class MapOutputTrackerMaster( */ def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int) : Seq[String] = { - if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD && - dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) { - val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId, - dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION) - if (blockManagerIds.nonEmpty) { - blockManagerIds.get.map(_.host) + val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull + if (shuffleStatus != null) { + // Check if the map output is pre-merged and if the merge ratio is above the threshold. + // If so, the location of the merged block is the preferred location. + val preferredLoc = if (pushBasedShuffleEnabled) { + shuffleStatus.withMergeStatuses { statuses => + val status = statuses(partitionId) + val numMaps = dep.rdd.partitions.length + if (status != null && status.getNumMissingMapOutputs(numMaps).toDouble / numMaps + <= (1 - REDUCER_PREF_LOCS_FRACTION)) { + Seq(status.location.host) + } else { + Nil + } + } } else { Nil } + if (preferredLoc.nonEmpty) { + preferredLoc + } else { + if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD && + dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) { + val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId, + dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION) + if (blockManagerIds.nonEmpty) { + blockManagerIds.get.map(_.host) + } else { + Nil + } + } else { + Nil + } + } } else { Nil } @@ -774,8 +1041,25 @@ private[spark] class MapOutputTrackerMaster( } } + // This method is only called in local-mode. Since push based shuffle won't be + // enabled in local-mode, this method returns empty list. + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq.empty.toIterator + } + + // This method is only called in local-mode. Since push based shuffle won't be + // enabled in local-mode, this method returns empty list. + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq.empty.toIterator + } + override def stop(): Unit = { - mapOutputRequests.offer(PoisonPill) + mapOutputTrackerMasterMessages.offer(PoisonPill) threadpool.shutdown() try { sendTracker(StopMapOutputTracker) @@ -799,6 +1083,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr val mapStatuses: Map[Int, Array[MapStatus]] = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + val mergeStatuses: Map[Int, Array[MergeStatus]] = + new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala + + private val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf) + /** * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching * the same shuffle block. 
@@ -812,61 +1101,150 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr startPartition: Int, endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId") - val statuses = getStatuses(shuffleId, conf) + val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf) try { - val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex + val actualEndMapIndex = + if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex logDebug(s"Convert map statuses for shuffle $shuffleId, " + s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition") MapOutputTracker.convertMapStatuses( - shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex) + shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex, + actualEndMapIndex, Option(mergedOutputStatuses)) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: mapStatuses.clear() + mergeStatuses.clear() + throw e + } + } + + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId") + // Fetch the map statuses and merge statuses again since they might have already been + // cleared by another task running in the same executor. + val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf) + try { + val mergeStatus = mergeResultStatuses(partitionId) + // If the original MergeStatus is no longer available, we cannot identify the list of + // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException in this case. + MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId) + // Use the MergeStatus's partition level bitmap since we are doing partition level fallback + MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, + mapOutputStatuses, mergeStatus.tracker) + } catch { + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it + case e: MetadataFetchFailedException => + mapStatuses.clear() + mergeStatuses.clear() + throw e + } + } + + override def getMapSizesForMergeResult( + shuffleId: Int, + partitionId: Int, + chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId") + // Fetch the map statuses and merge statuses again since they might have already been + // cleared by another task running in the same executor. + val (mapOutputStatuses, _) = getStatuses(shuffleId, conf) + try { + MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses, + chunkTracker) + } catch { + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + case e: MetadataFetchFailedException => + mapStatuses.clear() + mergeStatuses.clear() throw e } } /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * Get or fetch the array of MapStatuses and MergeStatuses if push based shuffle enabled + * for a given shuffle ID. NOTE: clients MUST synchronize * on this array when reading it, because on the driver, we may be changing it in place. * * (It would be nice to remove this restriction in the future.) 
*/ - private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - val startTimeNs = System.nanoTime() - fetchingLock.withLock(shuffleId) { - var fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - try { - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf) - } catch { - case e: SparkException => - throw new MetadataFetchFailedException(shuffleId, -1, - s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " + - e.getCause) + private def getStatuses( + shuffleId: Int, + conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = { + if (fetchMergeResult) { + val mapOutputStatuses = mapStatuses.get(shuffleId).orNull + val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull + + if (mapOutputStatuses == null || mergeOutputStatuses == null) { + logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them") + val startTimeNs = System.nanoTime() + fetchingLock.withLock(shuffleId) { + var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull + var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull + if (fetchedMapStatuses == null || fetchedMergeStatuses == null) { + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + val fetchedBytes = + askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId)) + try { + fetchedMapStatuses = + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf) + fetchedMergeStatuses = + MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, conf) + } catch { + case e: SparkException => + throw new MetadataFetchFailedException(shuffleId, -1, + s"Unable to deserialize broadcasted map/merge statuses" + + s" for shuffle $shuffleId: " + e.getCause) + } + logInfo("Got the map/merge output locations") + mapStatuses.put(shuffleId, fetchedMapStatuses) + mergeStatuses.put(shuffleId, fetchedMergeStatuses) } - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) + logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId took " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + (fetchedMapStatuses, fetchedMergeStatuses) } - logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + - s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") - fetchedStatuses + } else { + (mapOutputStatuses, mergeOutputStatuses) } } else { - statuses + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTimeNs = System.nanoTime() + fetchingLock.withLock(shuffleId) { + var fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + try { + fetchedStatuses = + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf) + } catch { + case e: SparkException => + throw new MetadataFetchFailedException(shuffleId, -1, + s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " + + e.getCause) + } + 
logInfo("Got the map output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + (fetchedStatuses, null) + } + } else { + (statuses, null) + } } } - /** Unregister shuffle data. */ def unregisterShuffle(shuffleId: Int): Unit = { mapStatuses.remove(shuffleId) + mergeStatuses.remove(shuffleId) } /** @@ -880,6 +1258,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr logInfo("Updating epoch to " + newEpoch + " and clearing cache") epoch = newEpoch mapStatuses.clear() + mergeStatuses.clear() } } } @@ -891,11 +1270,13 @@ private[spark] object MapOutputTracker extends Logging { private val DIRECT = 0 private val BROADCAST = 1 - // Serialize an array of map output locations into an efficient byte format so that we can send - // it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will - // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeMapStatuses( - statuses: Array[MapStatus], + private val SHUFFLE_PUSH_MAP_ID = -1 + + // Serialize an array of map/merge output locations into an efficient byte format so that we can + // send it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will + // generally be pretty compressible because many outputs will be on the same hostname. + def serializeOutputStatuses[T <: ShuffleOutputStatus]( + statuses: Array[T], broadcastManager: BroadcastManager, isLocal: Boolean, minBroadcastSize: Int, @@ -931,15 +1312,16 @@ private[spark] object MapOutputTracker extends Logging { oos.close() } val outArr = out.toByteArray - logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length) + logInfo("Broadcast outputstatuses size = " + outArr.length + ", actual size = " + arr.length) (outArr, bcast) } else { (arr, null) } } - // Opposite of serializeMapStatuses. - def deserializeMapStatuses(bytes: Array[Byte], conf: SparkConf): Array[MapStatus] = { + // Opposite of serializeOutputStatuses. + def deserializeOutputStatuses[T <: ShuffleOutputStatus]( + bytes: Array[Byte], conf: SparkConf): Array[T] = { assert (bytes.length > 0) def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = { @@ -958,20 +1340,22 @@ private[spark] object MapOutputTracker extends Logging { bytes(0) match { case DIRECT => - deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]] + deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[T]] case BROADCAST => try { // deserialize the Broadcast, pull .value array out of it, and then deserialize that val bcast = deserializeObject(bytes, 1, bytes.length - 1). asInstanceOf[Broadcast[Array[Byte]]] - logInfo("Broadcast mapstatuses size = " + bytes.length + + logInfo("Broadcast outputstatuses size = " + bytes.length + ", actual size = " + bcast.value.length) // Important - ignore the DIRECT tag ! 
Start from offset 1 - deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]] + deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[T]] } catch { case e: IOException => - logWarning("Exception encountered during deserializing broadcasted map statuses: ", e) - throw new SparkException("Unable to deserialize broadcasted map statuses", e) + logWarning("Exception encountered during deserializing broadcasted" + + " output statuses: ", e) + throw new SparkException("Unable to deserialize broadcasted" + + " output statuses", e) } case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0)) } @@ -983,15 +1367,19 @@ private[spark] object MapOutputTracker extends Logging { * stored at that block manager. * Note that empty blocks are filtered in the result. * + * If push-based shuffle is enabled and an array of merge statuses is available, prioritize + * the locations of the merged shuffle partitions over unmerged shuffle blocks. + * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. * * @param shuffleId Identifier for the shuffle * @param startPartition Start of map output partition ID range (included in range) * @param endPartition End of map output partition ID range (excluded from range) - * @param statuses List of map statuses, indexed by map partition index. + * @param mapStatuses List of map statuses, indexed by map partition index. * @param startMapIndex Start Map index. * @param endMapIndex End Map index. + * @param mergeStatuses List of merge statuses, index by reduce ID. * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, * and the second item is a sequence of (shuffle block id, shuffle block size, map index) * tuples describing the shuffle blocks that are stored at that block manager. @@ -1000,18 +1388,57 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus], + mapStatuses: Array[MapStatus], startMapIndex : Int, - endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - assert (statuses != null) + endMapIndex: Int, + mergeStatuses: Option[Array[MergeStatus]] = None): + Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + assert (mapStatuses != null) val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] - val iter = statuses.iterator.zipWithIndex - for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) { - if (status == null) { - val errorMessage = s"Missing an output location for shuffle $shuffleId" - logError(errorMessage) - throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) - } else { + // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle + // partition consists of blocks merged in random order, we are unable to serve map index + // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of + // map outputs, it usually indicates skewed partitions which push-based shuffle delegates + // to AQE to handle. 
+ // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle, + // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end + // TODO: map indexes + if (mergeStatuses.exists(_.nonEmpty) && startMapIndex == 0 + && endMapIndex == mapStatuses.length) { + // We have MergeStatus and full range of mapIds are requested so return a merged block. + val numMaps = mapStatuses.length + mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach { + case (mergeStatus, partId) => + val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) { + // If MergeStatus is available for the given partition, add location of the + // pre-merged shuffle partition for this partition ID. Here we create a + // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is + // a merged shuffle block. + splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, -1)) + // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper + // shuffle partition blocks, fetch the original map produced shuffle partition blocks + val mapStatusesWithIndex = mapStatuses.zipWithIndex + mergeStatus.getMissingMaps(numMaps).map(mapStatusesWithIndex) + } else { + // If MergeStatus is not available for the given partition, fall back to + // fetching all the original mapper shuffle partition blocks + mapStatuses.zipWithIndex.toSeq + } + // Add location for the mapper shuffle partition blocks + for ((mapStatus, mapIndex) <- remainingMapStatuses) { + validateStatus(mapStatus, shuffleId, partId) + val size = mapStatus.getSizeForBlock(partId) + if (size != 0) { + splitsByAddress.getOrElseUpdate(mapStatus.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapStatus.mapId, partId), size, mapIndex)) + } + } + } + } else { + val iter = mapStatuses.iterator.zipWithIndex + for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) { + validateStatus(status, shuffleId, startPartition) for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { @@ -1024,4 +1451,47 @@ private[spark] object MapOutputTracker extends Logging { splitsByAddress.mapValues(_.toSeq).iterator } + + /** + * Given a shuffle ID, a partition ID, an array of map statuses, and bitmap corresponding + * to either a merged shuffle partition or a merged shuffle partition chunk, identify + * the metadata about the shuffle partition blocks that are merged into the merged shuffle + * partition or partition chunk represented by the bitmap. + * + * @param shuffleId Identifier for the shuffle + * @param partitionId The partition ID of the MergeStatus for which we look for the metadata + * of the merged shuffle partition blocks + * @param mapStatuses List of map statuses, indexed by map ID + * @param tracker bitmap containing mapIndexes that belong to the merged block or merged + * block chunk. + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. 
+ */ + def getMapStatusesForMergeStatus( + shuffleId: Int, + partitionId: Int, + mapStatuses: Array[MapStatus], + tracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + assert (mapStatuses != null && tracker != null) + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] + for ((status, mapIndex) <- mapStatuses.zipWithIndex) { + // Only add blocks that are merged + if (tracker.contains(mapIndex)) { + MapOutputTracker.validateStatus(status, shuffleId, partitionId) + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, status.mapId, partitionId), + status.getSizeForBlock(partitionId), mapIndex)) + } + } + splitsByAddress.mapValues(_.toSeq).iterator + } + + def validateStatus(status: ShuffleOutputStatus, shuffleId: Int, partition: Int) : Unit = { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId partition $partition" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, partition, errorMessage) + } + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 36873c7f19..568bcf9b84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -955,6 +955,15 @@ private[spark] class SparkSubmit extends Logging { } catch { case t: Throwable => throw findCause(t) + } finally { + if (!isShell(args.primaryResource) && !isSqlShell(args.mainClass) && + !isThriftServer(args.mainClass)) { + try { + SparkContext.getActive.foreach(_.stop()) + } catch { + case e: Throwable => logError(s"Failed to close SparkContext: $e") + } + } } } @@ -1189,7 +1198,7 @@ private[spark] object SparkSubmitUtils extends Logging { sp.setM2compatible(true) sp.setUsepoms(true) sp.setRoot(sys.env.getOrElse( - "DEFAULT_ARTIFACT_REPOSITORY", "https://dl.bintray.com/spark-packages/maven")) + "DEFAULT_ARTIFACT_REPOSITORY", "https://repos.spark-packages.org/")) sp.setName("spark-packages") cr.add(sp) cr diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c2e7c4dae0..a92d9fab6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -456,7 +456,8 @@ private[spark] class DAGScheduler( // since we can't do it in the RDD constructor because # of partitions is unknown logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " + s"shuffle ${shuffleDep.shuffleId}") - mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length) + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length, + shuffleDep.partitioner.numPartitions) } stage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 1239c32cee..07eed76805 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -28,12 +28,18 @@ import org.apache.spark.internal.config import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils +/** + * A common trait between [[MapStatus]] and [[MergeStatus]]. This allows us to reuse existing + * code to handle MergeStatus inside MapOutputTracker. 
+ */ +private[spark] trait ShuffleOutputStatus + /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the * task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing * on to the reduce tasks. */ -private[spark] sealed trait MapStatus { +private[spark] sealed trait MapStatus extends ShuffleOutputStatus { /** Location where this task output is. */ def location: BlockManagerId diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala new file mode 100644 index 0000000000..77d8f8e040 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.{Externalizable, ObjectInput, ObjectOutput} + +import org.roaringbitmap.RoaringBitmap + +import org.apache.spark.network.shuffle.protocol.MergeStatuses +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils + +/** + * The status for the result of merging shuffle partition blocks per individual shuffle partition + * maintained by the scheduler. The scheduler would separate the + * [[org.apache.spark.network.shuffle.protocol.MergeStatuses]] received from + * ExternalShuffleService into individual [[MergeStatus]] which is maintained inside + * MapOutputTracker to be served to the reducers when they start fetching shuffle partition + * blocks. Note that, the reducers are ultimately fetching individual chunks inside a merged + * shuffle file, as explained in [[org.apache.spark.network.shuffle.RemoteBlockPushResolver]]. + * Between the scheduler maintained MergeStatus and the shuffle service maintained per shuffle + * partition meta file, we are effectively dividing the metadata for a push-based shuffle into + * 2 layers. The scheduler would track the top-level metadata at the shuffle partition level + * with MergeStatus, and the shuffle service would maintain the partition level metadata about + * how to further divide a merged shuffle partition into multiple chunks with the per-partition + * meta file. This helps to reduce the amount of data the scheduler needs to maintain for + * push-based shuffle. 
+ */ +private[spark] class MergeStatus( + private[this] var loc: BlockManagerId, + private[this] var mapTracker: RoaringBitmap, + private[this] var size: Long) + extends Externalizable with ShuffleOutputStatus { + + protected def this() = this(null, null, -1) // For deserialization only + + def location: BlockManagerId = loc + + def totalSize: Long = size + + def tracker: RoaringBitmap = mapTracker + + /** + * Get the list of mapper IDs for missing mapper partition blocks that are not merged. + * The reducer will use this information to decide which shuffle partition blocks to + * fetch in the original way. + */ + def getMissingMaps(numMaps: Int): Seq[Int] = { + (0 until numMaps).filter(i => !mapTracker.contains(i)) + } + + /** + * Get the number of missing map outputs for missing mapper partition blocks that are not merged. + */ + def getNumMissingMapOutputs(numMaps: Int): Int = { + (0 until numMaps).count(i => !mapTracker.contains(i)) + } + + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { + loc.writeExternal(out) + mapTracker.writeExternal(out) + out.writeLong(size) + } + + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { + loc = BlockManagerId(in) + mapTracker = new RoaringBitmap() + mapTracker.readExternal(in) + size = in.readLong() + } +} + +private[spark] object MergeStatus { + // Dummy number of reduces for the tests where push based shuffle is not enabled + val SHUFFLE_PUSH_DUMMY_NUM_REDUCES = 1 + + /** + * Separate a MergeStatuses received from an ExternalShuffleService into individual + * MergeStatus. The scheduler is responsible for providing the location information + * for the given ExternalShuffleService. + */ + def convertMergeStatusesToMergeStatusArr( + mergeStatuses: MergeStatuses, + loc: BlockManagerId): Seq[(Int, MergeStatus)] = { + assert(mergeStatuses.bitmaps.length == mergeStatuses.reduceIds.length && + mergeStatuses.bitmaps.length == mergeStatuses.sizes.length) + val mergerLoc = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, loc.host, loc.port) + mergeStatuses.bitmaps.zipWithIndex.map { + case (bitmap, index) => + val mergeStatus = new MergeStatus(mergerLoc, bitmap, mergeStatuses.sizes(index)) + (mergeStatuses.reduceIds(index), mergeStatus) + } + } + + def apply(loc: BlockManagerId, bitmap: RoaringBitmap, size: Long): MergeStatus = { + new MergeStatus(loc, bitmap, size) + } +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 83fe450425..f4b47e2bb0 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -21,17 +21,19 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ +import org.roaringbitmap.RoaringBitmap import org.apache.spark.LocalSparkContext._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MAX_SIZE} +import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} -import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} +import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus, MergeStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} -class MapOutputTrackerSuite 
extends SparkFunSuite { +class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { private val conf = new SparkConf private def newTrackerMaster(sparkConf: SparkConf = conf) = { @@ -58,7 +60,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) @@ -82,7 +84,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), @@ -105,7 +107,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), @@ -140,7 +142,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { mapWorkerTracker.trackerEndpoint = mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - masterTracker.registerShuffle(10, 1) + masterTracker.registerShuffle(10, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) mapWorkerTracker.updateEpoch(masterTracker.getEpoch) // This is expected to fail because no outputs have been registered for the shuffle. intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) } @@ -183,7 +185,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) // Message size should be ~123B, and no exception should be thrown - masterTracker.registerShuffle(10, 1) + masterTracker.registerShuffle(10, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 5)) val senderAddress = RpcAddress("localhost", 12345) @@ -217,7 +219,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostA with output size 2 // on hostA with output size 2 // on hostB with output size 3 - tracker.registerShuffle(10, 3) + tracker.registerShuffle(10, 3, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), Array(2L), 5)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), @@ -260,7 +262,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. 
// Note that the size is hand-selected here because map output statuses are compressed before // being sent. - masterTracker.registerShuffle(20, 100) + masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) @@ -306,7 +308,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) - tracker.registerShuffle(10, 2) + tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -332,6 +334,219 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.shutdown() } + test("SPARK-32921: master register and unregister merge result") { + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, true) + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 4, 2) + assert(tracker.containsShuffle(10)) + val bitmap = new RoaringBitmap() + bitmap.add(0) + bitmap.add(1) + + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + bitmap, 1000L)) + tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), + bitmap, 1000L)) + assert(tracker.getNumAvailableMergeResults(10) == 2) + tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000)) + assert(tracker.getNumAvailableMergeResults(10) == 1) + tracker.stop() + rpcEnv.shutdown() + } + + test("SPARK-32921: get map sizes with merged shuffle") { + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, true) + val hostname = "localhost" + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) + + val masterTracker = newTrackerMaster() + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) + val slaveTracker = new MapOutputTrackerWorker(conf) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + masterTracker.registerShuffle(10, 4, 1) + slaveTracker.updateEpoch(masterTracker.getEpoch) + val bitmap = new RoaringBitmap() + bitmap.add(0) + bitmap.add(1) + bitmap.add(3) + + val blockMgrId = BlockManagerId("a", "hostA", 1000) + masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0)) + masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1)) + masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2)) + masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3)) + + masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, + bitmap, 3000L)) + slaveTracker.updateEpoch(masterTracker.getEpoch) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === + Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, -1, 0), 3000, -1), + (ShuffleBlockId(10, 2, 0), size1000, 2))))) 
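The expected sequence in the assertion above follows directly from the merge bitmap: map indexes 0, 1 and 3 are covered by the merged block (reported as `ShuffleBlockId(10, -1, 0)` with map index -1), while index 2 falls back to its original per-map block. A minimal sketch of that split, using only the RoaringBitmap dependency this suite already imports (the object name is made up for illustration):

```scala
import org.roaringbitmap.RoaringBitmap

// Illustration only: split map outputs into "served by the merged block" and
// "fetched as original per-map blocks", the way the assertion above expects.
object MergeBitmapSplit {
  def main(args: Array[String]): Unit = {
    val numMaps = 4
    val merged = new RoaringBitmap()
    merged.add(0)
    merged.add(1)
    merged.add(3)

    val servedByMergedBlock = (0 until numMaps).filter(i => merged.contains(i))
    val fetchedIndividually = (0 until numMaps).filterNot(i => merged.contains(i))

    println(s"merged block covers map indexes: ${servedByMergedBlock.mkString(", ")}") // 0, 1, 3
    println(s"fetched as original blocks: ${fetchedIndividually.mkString(", ")}")      // 2
  }
}
```

This is the same filtering that `MergeStatus.getMissingMaps` performs in the new `MergeStatus.scala` above.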
+ + masterTracker.stop() + slaveTracker.stop() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() + } + + test("SPARK-32921: get map statuses from merged shuffle") { + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, true) + val hostname = "localhost" + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) + + val masterTracker = newTrackerMaster() + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) + val slaveTracker = new MapOutputTrackerWorker(conf) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + masterTracker.registerShuffle(10, 4, 1) + slaveTracker.updateEpoch(masterTracker.getEpoch) + // This is expected to fail because no outputs have been registered for the shuffle. + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + val bitmap = new RoaringBitmap() + bitmap.add(0) + bitmap.add(1) + bitmap.add(2) + bitmap.add(3) + + val blockMgrId = BlockManagerId("a", "hostA", 1000) + masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0)) + masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1)) + masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2)) + masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3)) + + masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId, + bitmap, 4000L)) + slaveTracker.updateEpoch(masterTracker.getEpoch) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + assert(slaveTracker.getMapSizesForMergeResult(10, 0).toSeq === + Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0), + (ShuffleBlockId(10, 1, 0), size1000, 1), (ShuffleBlockId(10, 2, 0), size1000, 2), + (ShuffleBlockId(10, 3, 0), size1000, 3))))) + masterTracker.stop() + slaveTracker.stop() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() + } + + test("SPARK-32921: get map statuses for merged shuffle block chunks") { + conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set(IS_TESTING, true) + val hostname = "localhost" + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) + + val masterTracker = newTrackerMaster() + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) + val slaveTracker = new MapOutputTrackerWorker(conf) + slaveTracker.trackerEndpoint = + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + masterTracker.registerShuffle(10, 4, 1) + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val blockMgrId = BlockManagerId("a", "hostA", 1000) + masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0)) + masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1)) + masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2)) + masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3)) + + val chunkBitmap = new RoaringBitmap() + chunkBitmap.add(0) + chunkBitmap.add(2) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + assert(slaveTracker.getMapSizesForMergeResult(10, 0, chunkBitmap).toSeq === + Seq((blockMgrId, 
ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0), + (ShuffleBlockId(10, 2, 0), size1000, 2)))) + ) + masterTracker.stop() + slaveTracker.stop() + rpcEnv.shutdown() + slaveRpcEnv.shutdown() + } + + test("SPARK-32921: getPreferredLocationsForShuffle with MergeStatus") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + sc = new SparkContext("local", "test", conf.clone()) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + // Setup 5 map tasks + // on hostA with output size 2 + // on hostA with output size 2 + // on hostB with output size 3 + // on hostB with output size 3 + // on hostC with output size 1 + // on hostC with output size 1 + tracker.registerShuffle(10, 6, 1) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(2L), 5)) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(2L), 6)) + tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(3L), 7)) + tracker.registerMapOutput(10, 3, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(3L), 8)) + tracker.registerMapOutput(10, 4, MapStatus(BlockManagerId("c", "hostC", 1000), + Array(1L), 9)) + tracker.registerMapOutput(10, 5, MapStatus(BlockManagerId("c", "hostC", 1000), + Array(1L), 10)) + + val rdd = sc.parallelize(1 to 6, 6).map(num => (num, num).asInstanceOf[Product2[Int, Int]]) + val mockShuffleDep = mock(classOf[ShuffleDependency[Int, Int, _]]) + when(mockShuffleDep.shuffleId).thenReturn(10) + when(mockShuffleDep.partitioner).thenReturn(new HashPartitioner(1)) + when(mockShuffleDep.rdd).thenReturn(rdd) + + // Prepare a MergeStatus that merges 4 out of 5 blocks + val bitmap80 = new RoaringBitmap() + bitmap80.add(0) + bitmap80.add(1) + bitmap80.add(2) + bitmap80.add(3) + bitmap80.add(4) + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + bitmap80, 11)) + + val preferredLocs1 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0) + assert(preferredLocs1.nonEmpty) + assert(preferredLocs1.length === 1) + assert(preferredLocs1.head === "hostA") + + tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000)) + // Prepare another MergeStatus that merges only 1 out of 5 blocks + val bitmap20 = new RoaringBitmap() + bitmap20.add(0) + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + bitmap20, 2)) + + val preferredLocs2 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0) + assert(preferredLocs2.nonEmpty) + assert(preferredLocs2.length === 2) + assert(preferredLocs2 === Seq("hostA", "hostB")) + + tracker.stop() + rpcEnv.shutdown() + } + test("SPARK-34939: remote fetch using broadcast if broadcasted value is destroyed") { val newConf = new SparkConf newConf.set(RPC_MESSAGE_MAX_SIZE, 1) @@ -346,7 +561,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { rpcEnv.stop(masterTracker.trackerEndpoint) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - masterTracker.registerShuffle(20, 100) + masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) @@ -368,9 +583,85 @@ class MapOutputTrackerSuite extends SparkFunSuite { shuffleStatus.cachedSerializedBroadcast.destroy(true) } val err = intercept[SparkException] { - 
MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf) + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf) + } + assert(err.getMessage.contains("Unable to deserialize broadcasted output statuses")) + } + } + + test("SPARK-32921: test new protocol changes fetching both Map and Merge status in single RPC") { + val newConf = new SparkConf + newConf.set(RPC_MESSAGE_MAX_SIZE, 1) + newConf.set(RPC_ASK_TIMEOUT, "1") // Fail fast + newConf.set(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST, 10240L) // 10 KiB << 1MiB framesize + newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true) + newConf.set(IS_TESTING, true) + + // needs TorrentBroadcast so need a SparkContext + withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc => + val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val rpcEnv = sc.env.rpcEnv + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.stop(masterTracker.trackerEndpoint) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) + val bitmap1 = new RoaringBitmap() + bitmap1.add(1) + + masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) + (0 until 100).foreach { i => + masterTracker.registerMapOutput(20, i, new CompressedMapStatus( + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) } - assert(err.getMessage.contains("Unable to deserialize broadcasted map statuses")) + masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000), + bitmap1, 1000L)) + + val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new SecurityManager(conf)) + val mapWorkerTracker = new MapOutputTrackerWorker(conf) + mapWorkerTracker.trackerEndpoint = + mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + val fetchedBytes = mapWorkerTracker.trackerEndpoint + .askSync[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(20)) + assert(masterTracker.getNumAvailableMergeResults(20) == 1) + assert(masterTracker.getNumAvailableOutputs(20) == 100) + + val mapOutput = + MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, newConf) + val mergeOutput = + MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, newConf) + assert(mapOutput.length == 100) + assert(mergeOutput.length == 1) + mapWorkerTracker.stop() + masterTracker.stop() } } + + test("SPARK-32921: unregister merge result if it is present and contains the map Id") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 4, 2) + assert(tracker.containsShuffle(10)) + val bitmap1 = new RoaringBitmap() + bitmap1.add(0) + bitmap1.add(1) + tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000), + bitmap1, 1000L)) + + val bitmap2 = new RoaringBitmap() + bitmap2.add(5) + bitmap2.add(6) + tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000), + bitmap2, 1000L)) + assert(tracker.getNumAvailableMergeResults(10) == 2) + tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000), Option(0)) + assert(tracker.getNumAvailableMergeResults(10) == 1) + tracker.unregisterMergeResult(10, 1, BlockManagerId("b", "hostB", 1000), Option(1)) + assert(tracker.getNumAvailableMergeResults(10) == 1) + tracker.unregisterMergeResult(10, 1, 
BlockManagerId("b", "hostB", 1000), Option(5)) + assert(tracker.getNumAvailableMergeResults(10) == 0) + tracker.stop() + rpcEnv.shutdown() + } } diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala index e433f42900..d808823987 100644 --- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala @@ -19,7 +19,7 @@ package org.apache.spark import org.apache.spark.benchmark.Benchmark import org.apache.spark.benchmark.BenchmarkBase -import org.apache.spark.scheduler.CompressedMapStatus +import org.apache.spark.scheduler.{CompressedMapStatus, MergeStatus} import org.apache.spark.storage.BlockManagerId /** @@ -50,7 +50,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase { val shuffleId = 10 - tracker.registerShuffle(shuffleId, numMaps) + tracker.registerShuffle(shuffleId, numMaps, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val r = new scala.util.Random(912) (0 until numMaps).foreach { i => tracker.registerMapOutput(shuffleId, i, @@ -66,7 +66,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase { var serializedMapStatusSizes = 0 var serializedBroadcastSizes = 0 - val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeMapStatuses( + val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeOutputStatuses( shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize, sc.getConf) serializedMapStatusSizes = serializedMapStatus.length @@ -75,12 +75,12 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase { } benchmark.addCase("Serialization") { _ => - MapOutputTracker.serializeMapStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager, + MapOutputTracker.serializeOutputStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize, sc.getConf) } benchmark.addCase("Deserialization") { _ => - val result = MapOutputTracker.deserializeMapStatuses(serializedMapStatus, sc.getConf) + val result = MapOutputTracker.deserializeOutputStatuses(serializedMapStatus, sc.getConf) assert(result.length == numMaps) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 56684d9b03..126faec334 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io.File import java.util.{Locale, Properties} -import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService} +import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService } import scala.collection.JavaConverters._ @@ -33,7 +33,7 @@ import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} -import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd} +import org.apache.spark.scheduler.{MapStatus, MergeStatus, MyRDD, SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId} @@ -367,7 +367,7 @@ abstract class 
ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val shuffleMapRdd = new MyRDD(sc, 1, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleHandle = manager.registerShuffle(0, shuffleDep) - mapTrackerMaster.registerShuffle(0, 1) + mapTrackerMaster.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) // first attempt -- its successful val context1 = diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 055ee0debe..707e1684f7 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -53,7 +53,7 @@ import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, Transpo import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExecutorDiskUtils, ExternalBlockStoreClient} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} -import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, SparkListenerBlockUpdated} +import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, MergeStatus, SparkListenerBlockUpdated} import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend} import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} @@ -1956,7 +1956,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE Files.write(bm1.diskBlockManager.getFile(shuffleIndex).toPath(), shuffleIndexBlockContent) Files.write(bm2.diskBlockManager.getFile(shuffleIndex2).toPath(), shuffleIndexBlockContent) - mapOutputTracker.registerShuffle(0, 1) + mapOutputTracker.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES) val decomManager = new BlockManagerDecommissioner(conf, bm1) try { mapOutputTracker.registerMapOutput(0, 0, MapStatus(bm1.blockManagerId, Array(blockSize), 0)) diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 index 7534268a0a..1b1bdf56f7 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 @@ -15,6 +15,7 @@ apacheds-i18n/2.0.0-M15//apacheds-i18n-2.0.0-M15.jar apacheds-kerberos-codec/2.0.0-M15//apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api/1.0.0-M20//api-asn1-api-1.0.0-M20.jar api-util/1.0.0-M20//api-util-1.0.0-M20.jar +arpack/1.3.2//arpack-1.3.2.jar arpack_combined_all/0.1//arpack_combined_all-0.1.jar arrow-format/2.0.0//arrow-format-2.0.0.jar arrow-memory-core/2.0.0//arrow-memory-core-2.0.0.jar @@ -25,6 +26,7 @@ automaton/1.11-8//automaton-1.11-8.jar avro-ipc/1.10.2//avro-ipc-1.10.2.jar avro-mapred/1.10.2//avro-mapred-1.10.2.jar avro/1.10.2//avro-1.10.2.jar +blas/1.3.2//blas-1.3.2.jar bonecp/0.8.0.RELEASE//bonecp-0.8.0.RELEASE.jar breeze-macros_2.12/1.0//breeze-macros_2.12-1.0.jar breeze_2.12/1.0//breeze_2.12-1.0.jar @@ -173,6 +175,7 @@ kubernetes-model-policy/5.3.0//kubernetes-model-policy-5.3.0.jar kubernetes-model-rbac/5.3.0//kubernetes-model-rbac-5.3.0.jar kubernetes-model-scheduling/5.3.0//kubernetes-model-scheduling-5.3.0.jar kubernetes-model-storageclass/5.3.0//kubernetes-model-storageclass-5.3.0.jar +lapack/1.3.2//lapack-1.3.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar 
libthrift/0.12.0//libthrift-0.12.0.jar diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index a86b4832b3..d5d0890c32 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -10,6 +10,7 @@ annotations/17.0.0//annotations-17.0.0.jar antlr-runtime/3.5.2//antlr-runtime-3.5.2.jar antlr4-runtime/4.8-1//antlr4-runtime-4.8-1.jar aopalliance-repackaged/2.6.1//aopalliance-repackaged-2.6.1.jar +arpack/1.3.2//arpack-1.3.2.jar arpack_combined_all/0.1//arpack_combined_all-0.1.jar arrow-format/2.0.0//arrow-format-2.0.0.jar arrow-memory-core/2.0.0//arrow-memory-core-2.0.0.jar @@ -20,6 +21,7 @@ automaton/1.11-8//automaton-1.11-8.jar avro-ipc/1.10.2//avro-ipc-1.10.2.jar avro-mapred/1.10.2//avro-mapred-1.10.2.jar avro/1.10.2//avro-1.10.2.jar +blas/1.3.2//blas-1.3.2.jar bonecp/0.8.0.RELEASE//bonecp-0.8.0.RELEASE.jar breeze-macros_2.12/1.0//breeze-macros_2.12-1.0.jar breeze_2.12/1.0//breeze_2.12-1.0.jar @@ -144,6 +146,7 @@ kubernetes-model-policy/5.3.0//kubernetes-model-policy-5.3.0.jar kubernetes-model-rbac/5.3.0//kubernetes-model-rbac-5.3.0.jar kubernetes-model-scheduling/5.3.0//kubernetes-model-scheduling-5.3.0.jar kubernetes-model-storageclass/5.3.0//kubernetes-model-storageclass-5.3.0.jar +lapack/1.3.2//lapack-1.3.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.12.0//libthrift-0.12.0.jar diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 896ed77c87..22925a54a1 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -288,4 +288,4 @@ Here is the documentation on the standard connectors both from Apache and the cl * [The Azure Blob Filesystem driver (ABFS)](https://docs.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-abfs-driver) * IBM Cloud Object Storage connector for Apache Spark: [Stocator](https://github.com/CODAIT/stocator), [IBM Object Storage](https://www.ibm.com/cloud/object-storage). From IBM. - +* [Using JindoFS SDK to access Alibaba Cloud OSS](https://github.com/aliyun/alibabacloud-jindofs). diff --git a/docs/ml-linalg-guide.md b/docs/ml-linalg-guide.md index 7390913634..719554af5a 100644 --- a/docs/ml-linalg-guide.md +++ b/docs/ml-linalg-guide.md @@ -82,7 +82,7 @@ WARN BLAS: Failed to load implementation from:com.github.fommil.netlib.NativeSys WARN BLAS: Failed to load implementation from:com.github.fommil.netlib.NativeRefBLAS ``` -If native libraries are not properly configured in the system, the Java implementation (f2jBLAS) will be used as fallback option. +If native libraries are not properly configured in the system, the Java implementation (javaBLAS) will be used as fallback option. ## Spark Configuration diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 1138afa45d..ee0933a30a 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -83,6 +83,8 @@ license: | - In Spark 3.2, the unit-to-unit interval literals like `INTERVAL '1-1' YEAR TO MONTH` are converted to ANSI interval types: `YearMonthIntervalType` or `DayTimeIntervalType`. In Spark 3.1 and earlier, such interval literals are converted to `CalendarIntervalType`. To restore the behavior before Spark 3.2, you can set `spark.sql.legacy.interval.enabled` to `true`. 
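Because the type change described in this note is easy to miss, a short sketch of how the new default and the legacy flag behave in a local session may help; it assumes a Spark 3.2 build, and the alias `i` is arbitrary:

```scala
import org.apache.spark.sql.SparkSession

object IntervalLiteralTypes {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("interval-literal-types")
      .getOrCreate()

    // Spark 3.2 default: the unit-to-unit literal is typed as an ANSI YearMonthIntervalType.
    spark.sql("SELECT INTERVAL '1-1' YEAR TO MONTH AS i").printSchema()

    // Restore the pre-3.2 behavior (CalendarIntervalType) via the legacy flag.
    spark.conf.set("spark.sql.legacy.interval.enabled", "true")
    spark.sql("SELECT INTERVAL '1-1' YEAR TO MONTH AS i").printSchema()

    spark.stop()
  }
}
```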
+ - In Spark 3.2, Spark supports `DayTimeIntervalType` and `YearMonthIntervalType` as inputs and outputs of `TRANSFORM` clause in Hive `SERDE` mode, the behavior is different between Hive `SERDE` mode and `ROW FORMAT DELIMITED` mode when these two types are used as inputs. In Hive `SERDE` mode, `DayTimeIntervalType` column is converted to `HiveIntervalDayTime`, its string format is `[-]?d h:m:s.n`, but in `ROW FORMAT DELIMITED` mode the format is `INTERVAL '[-]?d h:m:s.n' DAY TO TIME`. In Hive `SERDE` mode, `YearMonthIntervalType` column is converted to `HiveIntervalYearMonth`, its string format is `[-]?y-m`, but in `ROW FORMAT DELIMITED` mode the format is `INTERVAL '[-]?y-m' YEAR TO MONTH`. + ## Upgrading from Spark SQL 3.0 to 3.1 - In Spark 3.1, statistical aggregation function includes `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, `corr` will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` applied on a single element set. In Spark version 3.0 and earlier, it will return `Double.NaN` in such case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`. diff --git a/docs/sql-ref-syntax-ddl-create-table-hiveformat.md b/docs/sql-ref-syntax-ddl-create-table-hiveformat.md index 11ec2f1d9e..b2f5957416 100644 --- a/docs/sql-ref-syntax-ddl-create-table-hiveformat.md +++ b/docs/sql-ref-syntax-ddl-create-table-hiveformat.md @@ -39,14 +39,6 @@ CREATE [ EXTERNAL ] TABLE [ IF NOT EXISTS ] table_identifier [ LOCATION path ] [ TBLPROPERTIES ( key1=val1, key2=val2, ... ) ] [ AS select_statement ] - -row_format: - : SERDE serde_class [ WITH SERDEPROPERTIES (k1=v1, k2=v2, ... ) ] - | DELIMITED [ FIELDS TERMINATED BY fields_terminated_char [ ESCAPED BY escaped_char ] ] - [ COLLECTION ITEMS TERMINATED BY collection_items_terminated_char ] - [ MAP KEYS TERMINATED BY map_key_terminated_char ] - [ LINES TERMINATED BY row_terminated_char ] - [ NULL DEFINED AS null_char ] ``` Note that, the clauses between the columns definition clause and the AS SELECT clause can come in @@ -82,50 +74,10 @@ as any order. For example, you can write COMMENT table_comment after TBLPROPERTI * **INTO num_buckets BUCKETS** Specifies buckets numbers, which is used in `CLUSTERED BY` clause. - -* **row_format** - - Use the `SERDE` clause to specify a custom SerDe for one table. Otherwise, use the `DELIMITED` clause to use the native SerDe and specify the delimiter, escape character, null character and so on. - -* **SERDE** - - Specifies a custom SerDe for one table. - -* **serde_class** - - Specifies a fully-qualified class name of a custom SerDe. - -* **SERDEPROPERTIES** - - A list of key-value pairs that is used to tag the SerDe definition. - -* **DELIMITED** - The `DELIMITED` clause can be used to specify the native SerDe and state the delimiter, escape character, null character and so on. - -* **FIELDS TERMINATED BY** - - Used to define a column separator. - -* **COLLECTION ITEMS TERMINATED BY** - - Used to define a collection item separator. - -* **MAP KEYS TERMINATED BY** - - Used to define a map key separator. - -* **LINES TERMINATED BY** - - Used to define a row separator. - -* **NULL DEFINED AS** - - Used to define the specific value for NULL. - -* **ESCAPED BY** +* **row_format** - Used for escape mechanism. + Specifies the row format for input and output. See [HIVE FORMAT](sql-ref-syntax-hive-format.html) for more syntax details. 
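A compact sketch of the `DELIMITED` row format referenced above, assuming a Hive-enabled build (`enableHiveSupport()` needs the `spark-hive` module on the classpath); the table and column names are invented:

```scala
import org.apache.spark.sql.SparkSession

object DelimitedRowFormatExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("delimited-row-format")
      .enableHiveSupport() // the Hive-format DDL below goes through the Hive catalog
      .getOrCreate()

    spark.sql("""
      CREATE TABLE IF NOT EXISTS student (id INT, name STRING, scores ARRAY<INT>)
      ROW FORMAT DELIMITED
        FIELDS TERMINATED BY ','
        COLLECTION ITEMS TERMINATED BY '|'
      STORED AS TEXTFILE
    """)

    spark.sql("DESCRIBE FORMATTED student").show(truncate = false)
    spark.stop()
  }
}
```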
* **STORED AS** diff --git a/docs/sql-ref-syntax-hive-format.md b/docs/sql-ref-syntax-hive-format.md new file mode 100644 index 0000000000..8092e582d9 --- /dev/null +++ b/docs/sql-ref-syntax-hive-format.md @@ -0,0 +1,73 @@ +--- +layout: global +title: Hive Row Format +displayTitle: Hive Row Format +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +### Description + +Spark supports a Hive row format in `CREATE TABLE` and `TRANSFORM` clause to specify serde or text delimiter. +There are two ways to define a row format in `row_format` of `CREATE TABLE` and `TRANSFORM` clauses. + 1. `SERDE` clause to specify a custom SerDe class. + 2. `DELIMITED` clause to specify a delimiter, an escape character, a null character, and so on for the native SerDe. + +### Syntax + +```sql +row_format: + SERDE serde_class [ WITH SERDEPROPERTIES (k1=v1, k2=v2, ... ) ] + | DELIMITED [ FIELDS TERMINATED BY fields_terminated_char [ ESCAPED BY escaped_char ] ] + [ COLLECTION ITEMS TERMINATED BY collection_items_terminated_char ] + [ MAP KEYS TERMINATED BY map_key_terminated_char ] + [ LINES TERMINATED BY row_terminated_char ] + [ NULL DEFINED AS null_char ] +``` + +### Parameters + +* **SERDE serde_class** + + Specifies a fully-qualified class name of custom SerDe. + +* **SERDEPROPERTIES** + + A list of key-value pairs that is used to tag the SerDe definition. + +* **FIELDS TERMINATED BY** + + Used to define a column separator. + +* **COLLECTION ITEMS TERMINATED BY** + + Used to define a collection item separator. + +* **MAP KEYS TERMINATED BY** + + Used to define a map key separator. + +* **LINES TERMINATED BY** + + Used to define a row separator. + +* **NULL DEFINED AS** + + Used to define the specific value for NULL. + +* **ESCAPED BY** + + Used for escape mechanism. diff --git a/docs/sql-ref-syntax-qry-select-groupby.md b/docs/sql-ref-syntax-qry-select-groupby.md index b81a5e43d5..d7827f8880 100644 --- a/docs/sql-ref-syntax-qry-select-groupby.md +++ b/docs/sql-ref-syntax-qry-select-groupby.md @@ -24,8 +24,8 @@ license: | The `GROUP BY` clause is used to group the rows based on a set of specified grouping expressions and compute aggregations on the group of rows based on one or more specified aggregate functions. Spark also supports advanced aggregations to do multiple aggregations for the same input record set via `GROUPING SETS`, `CUBE`, `ROLLUP` clauses. -The grouping expressions and advanced aggregations can be mixed in the `GROUP BY` clause. -See more details in the `Mixed Grouping Analytics` section. When a FILTER clause is attached to +The grouping expressions and advanced aggregations can be mixed in the `GROUP BY` clause and nested in a `GROUPING SETS` clause. +See more details in the `Mixed/Nested Grouping Analytics` section. 
When a FILTER clause is attached to an aggregate function, only the matching rows are passed to that function. ### Syntax @@ -95,13 +95,17 @@ aggregate_name ( [ DISTINCT ] expression [ , ... ] ) [ FILTER ( WHERE boolean_ex (product, warehouse, location), (warehouse), (product), (warehouse, product), ())`. The N elements of a `CUBE` specification results in 2^N `GROUPING SETS`. -* **Mixed Grouping Analytics** +* **Mixed/Nested Grouping Analytics** - A GROUP BY clause can include multiple `group_expression`s and multiple `CUBE|ROLLUP|GROUPING SETS`s. + A GROUP BY clause can include multiple `group_expression`s and multiple `CUBE|ROLLUP|GROUPING SETS`s. + `GROUPING SETS` can also have nested `CUBE|ROLLUP|GROUPING SETS` clauses, e.g. + `GROUPING SETS(ROLLUP(warehouse, location), CUBE(warehouse, location))`, + `GROUPING SETS(warehouse, GROUPING SETS(location, GROUPING SETS(ROLLUP(warehouse, location), CUBE(warehouse, location))))`. `CUBE|ROLLUP` is just a syntax sugar for `GROUPING SETS`, please refer to the sections above for how to translate `CUBE|ROLLUP` to `GROUPING SETS`. `group_expression` can be treated as a single-group `GROUPING SETS` under this context. For multiple `GROUPING SETS` in the `GROUP BY` clause, we generate - a single `GROUPING SETS` by doing a cross-product of the original `GROUPING SETS`s. For example, + a single `GROUPING SETS` by doing a cross-product of the original `GROUPING SETS`s. For nested `GROUPING SETS` in the `GROUPING SETS` clause, + we simply take its grouping sets and strip it. For example, `GROUP BY warehouse, GROUPING SETS((product), ()), GROUPING SETS((location, size), (location), (size), ())` and `GROUP BY warehouse, ROLLUP(product), CUBE(location, size)` is equivalent to `GROUP BY GROUPING SETS( @@ -113,7 +117,10 @@ aggregate_name ( [ DISTINCT ] expression [ , ... ] ) [ FILTER ( WHERE boolean_ex (warehouse, location), (warehouse, size), (warehouse))`. - + + `GROUP BY GROUPING SETS(GROUPING SETS(warehouse), GROUPING SETS((warehouse, product)))` is equivalent to + `GROUP BY GROUPING SETS((warehouse), (warehouse, product))`. + * **aggregate_name** Specifies an aggregate function name (MIN, MAX, COUNT, SUM, AVG, etc.). diff --git a/docs/sql-ref-syntax-qry-select-transform.md b/docs/sql-ref-syntax-qry-select-transform.md index 814bd01ec2..21966f2e1c 100644 --- a/docs/sql-ref-syntax-qry-select-transform.md +++ b/docs/sql-ref-syntax-qry-select-transform.md @@ -33,14 +33,6 @@ SELECT TRANSFORM ( expression [ , ... ] ) USING command_or_script [ AS ( [ col_name [ col_type ] ] [ , ... ] ) ] [ ROW FORMAT row_format ] [ RECORDREADER record_reader_class ] - -row_format: - SERDE serde_class [ WITH SERDEPROPERTIES (k1=v1, k2=v2, ... ) ] - | DELIMITED [ FIELDS TERMINATED BY fields_terminated_char [ ESCAPED BY escaped_char ] ] - [ COLLECTION ITEMS TERMINATED BY collection_items_terminated_char ] - [ MAP KEYS TERMINATED BY map_key_terminated_char ] - [ LINES TERMINATED BY row_terminated_char ] - [ NULL DEFINED AS null_char ] ``` ### Parameters @@ -49,45 +41,9 @@ row_format: Specifies a combination of one or more values, operators and SQL functions that results in a value. -* **row_format** - - Otherwise, uses the `DELIMITED` clause to specify the native SerDe and state the delimiter, escape character, null character and so on. - -* **SERDE** - - Specifies a custom SerDe for one table. - -* **serde_class** - - Specifies a fully-qualified class name of a custom SerDe. 
- -* **DELIMITED** - - The `DELIMITED` clause can be used to specify the native SerDe and state the delimiter, escape character, null character and so on. - -* **FIELDS TERMINATED BY** - - Used to define a column separator. - -* **COLLECTION ITEMS TERMINATED BY** - - Used to define a collection item separator. - -* **MAP KEYS TERMINATED BY** - - Used to define a map key separator. - -* **LINES TERMINATED BY** - - Used to define a row separator. - -* **NULL DEFINED AS** - - Used to define the specific value for NULL. - -* **ESCAPED BY** +* **row_format** - Used for escape mechanism. + Specifies the row format for input and output. See [HIVE FORMAT](sql-ref-syntax-hive-format.html) for more syntax details. * **RECORDWRITER** diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala index 52aa5a6973..424526eafd 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types._ // NOTE: This class is instantiated and used on executor side only, no need to be serializable. private[avro] class AvroOutputWriter( - path: String, + val path: String, context: TaskAttemptContext, schema: StructType, avroSchema: Schema) extends OutputWriter { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala index 8b37fd6e7e..b6d64c79b1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala @@ -22,6 +22,7 @@ import java.{util => ju} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.connector.metric.CustomTaskMetric import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer @@ -105,4 +106,16 @@ private case class KafkaBatchPartitionReader( range } } + + override def currentMetricsValues(): Array[CustomTaskMetric] = { + val offsetOutOfRange = new CustomTaskMetric { + override def name(): String = "offsetOutOfRange" + override def value(): Long = consumer.getNumOffsetOutOfRange() + } + val dataLoss = new CustomTaskMetric { + override def name(): String = "dataLoss" + override def value(): Long = consumer.getNumDataLoss() + } + Array(offsetOutOfRange, dataLoss) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 7299b182ae..5c772abfe8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,6 +31,7 @@ import org.apache.spark.kafka010.KafkaConfigUpdater import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, 
TableCapability} +import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric} import org.apache.spark.sql.connector.read.{Batch, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.connector.write.{BatchWrite, LogicalWriteInfo, SupportsTruncate, WriteBuilder} @@ -503,9 +504,23 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister startingStreamOffsets, failOnDataLoss(caseInsensitiveOptions)) } + + override def supportedCustomMetrics(): Array[CustomMetric] = { + Array(new OffsetOutOfRangeMetric, new DataLossMetric) + } } } +private[spark] class OffsetOutOfRangeMetric extends CustomSumMetric { + override def name(): String = "offsetOutOfRange" + override def description(): String = "estimated number of fetched offsets out of range" +} + +private[spark] class DataLossMetric extends CustomSumMetric { + override def name(): String = "dataLoss" + override def description(): String = "number of data loss error" +} + private[kafka010] object KafkaSourceProvider extends Logging { private val ASSIGN = "assign" private val SUBSCRIBE_PATTERN = "subscribepattern" diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala index 5c92d110a6..37fe38ea94 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala @@ -239,6 +239,9 @@ private[kafka010] class KafkaDataConsumer( fetchedDataPool: FetchedDataPool) extends Logging { import KafkaDataConsumer._ + private var offsetOutOfRange = 0L + private var dataLoss = 0L + private val isTokenProviderEnabled = HadoopDelegationTokenManager.isServiceEnabled(SparkEnv.get.conf, "kafka") @@ -329,7 +332,14 @@ private[kafka010] class KafkaDataConsumer( reportDataLoss(topicPartition, groupId, failOnDataLoss, s"Cannot fetch offset $toFetchOffset", e) + + val oldToFetchOffsetd = toFetchOffset toFetchOffset = getEarliestAvailableOffsetBetween(consumer, toFetchOffset, untilOffset) + if (toFetchOffset == UNKNOWN_OFFSET) { + offsetOutOfRange += (untilOffset - oldToFetchOffsetd) + } else { + offsetOutOfRange += (toFetchOffset - oldToFetchOffsetd) + } } } @@ -350,6 +360,9 @@ private[kafka010] class KafkaDataConsumer( consumer.getAvailableOffsetRange() } + def getNumOffsetOutOfRange(): Long = offsetOutOfRange + def getNumDataLoss(): Long = dataLoss + /** * Release borrowed objects in data reader to the pool. Once the instance is created, caller * must call method after using the instance to make sure resources are not leaked. 
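The Kafka changes above wire a driver-side metric declaration (`supportedCustomMetrics()` on the scan) to task-side counters surfaced through `currentMetricsValues()` on the partition reader, matched by metric name. A minimal sketch of that pairing, using a made-up metric name (`rowsSkipped`) instead of the Kafka ones:

```scala
import org.apache.spark.sql.connector.metric.{CustomSumMetric, CustomTaskMetric}

// Driver side: declared once per scan; task values are aggregated by summing.
class RowsSkippedMetric extends CustomSumMetric {
  override def name(): String = "rowsSkipped"
  override def description(): String = "number of rows skipped by the reader"
}

// Task side: a PartitionReader would return this from currentMetricsValues();
// the name must match the driver-side metric for the values to be picked up.
object RowsSkippedTaskMetric {
  def apply(current: Long): CustomTaskMetric = new CustomTaskMetric {
    override def name(): String = "rowsSkipped"
    override def value(): Long = current
  }
}
```

The Kafka reader follows exactly this shape with its `offsetOutOfRange` and `dataLoss` metrics.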
@@ -596,6 +609,7 @@ private[kafka010] class KafkaDataConsumer( message: String, cause: Throwable = null): Unit = { val finalMessage = s"$message ${additionalMessage(topicPartition, groupId, failOnDataLoss)}" + dataLoss += 1 reportDataLoss0(failOnDataLoss, finalMessage, cause) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index fc911ed9ac..058563dfa1 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -575,11 +575,11 @@ class KafkaTestUtils( s"topic $topic still exists in the replica manager") // ensure that logs from all replicas are deleted if delete topic is marked successful assert(servers.forall(server => topicAndPartitions.forall(tp => - server.getLogManager().getLog(tp).isEmpty)), + server.getLogManager.getLog(tp).isEmpty)), s"topic $topic still exists in log manager") // ensure that topic is removed from all cleaner offsets assert(servers.forall(server => topicAndPartitions.forall { tp => - val checkpoints = server.getLogManager().liveLogDirs.map { logDir => + val checkpoints = server.getLogManager.liveLogDirs.map { logDir => new OffsetCheckpointFile(new File(logDir, "cleaner-offset-checkpoint")).read() } checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) diff --git a/graphx/pom.xml b/graphx/pom.xml index 3ed68c0652..c4fa38a1dc 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -60,9 +60,8 @@ guava - com.github.fommil.netlib - core - ${netlib.java.version} + dev.ludovic.netlib + blas net.sourceforge.f2j diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index db786a194e..d7099c5c95 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -19,9 +19,8 @@ package org.apache.spark.graphx.lib import scala.util.Random -import com.github.fommil.netlib.BLAS.{getInstance => blas} - import org.apache.spark.graphx._ +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.rdd._ /** Implementation of SVD++ algorithm. 
*/ @@ -102,22 +101,22 @@ object SVDPlusPlus { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) val rank = p.length - var pred = u + usr._3 + itm._3 + blas.ddot(rank, q, 1, usr._2, 1) + var pred = u + usr._3 + itm._3 + BLAS.nativeBLAS.ddot(rank, q, 1, usr._2, 1) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = ctx.attr - pred // updateP = (err * q - conf.gamma7 * p) * conf.gamma2 val updateP = q.clone() - blas.dscal(rank, err * conf.gamma2, updateP, 1) - blas.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1) + BLAS.nativeBLAS.dscal(rank, err * conf.gamma2, updateP, 1) + BLAS.nativeBLAS.daxpy(rank, -conf.gamma7 * conf.gamma2, p, 1, updateP, 1) // updateQ = (err * usr._2 - conf.gamma7 * q) * conf.gamma2 val updateQ = usr._2.clone() - blas.dscal(rank, err * conf.gamma2, updateQ, 1) - blas.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1) + BLAS.nativeBLAS.dscal(rank, err * conf.gamma2, updateQ, 1) + BLAS.nativeBLAS.daxpy(rank, -conf.gamma7 * conf.gamma2, q, 1, updateQ, 1) // updateY = (err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2 val updateY = q.clone() - blas.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1) - blas.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1) + BLAS.nativeBLAS.dscal(rank, err * usr._4 * conf.gamma2, updateY, 1) + BLAS.nativeBLAS.daxpy(rank, -conf.gamma7 * conf.gamma2, itm._2, 1, updateY, 1) ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)) ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)) } @@ -129,7 +128,7 @@ object SVDPlusPlus { ctx => ctx.sendToSrc(ctx.dstAttr._2), (g1, g2) => { val out = g1.clone() - blas.daxpy(out.length, 1.0, g2, 1, out, 1) + BLAS.nativeBLAS.daxpy(out.length, 1.0, g2, 1, out, 1) out }) val gJoinT1 = g.outerJoinVertices(t1) { @@ -137,7 +136,7 @@ object SVDPlusPlus { msg: Option[Array[Double]]) => if (msg.isDefined) { val out = vd._1.clone() - blas.daxpy(out.length, vd._4, msg.get, 1, out, 1) + BLAS.nativeBLAS.daxpy(out.length, vd._4, msg.get, 1, out, 1) (vd._1, out, vd._3, vd._4) } else { vd @@ -154,9 +153,9 @@ object SVDPlusPlus { (g1: (Array[Double], Array[Double], Double), g2: (Array[Double], Array[Double], Double)) => { val out1 = g1._1.clone() - blas.daxpy(out1.length, 1.0, g2._1, 1, out1, 1) + BLAS.nativeBLAS.daxpy(out1.length, 1.0, g2._1, 1, out1, 1) val out2 = g2._2.clone() - blas.daxpy(out2.length, 1.0, g2._2, 1, out2, 1) + BLAS.nativeBLAS.daxpy(out2.length, 1.0, g2._2, 1, out2, 1) (out1, out2, g1._3 + g2._3) }) val gJoinT2 = g.outerJoinVertices(t2) { @@ -164,9 +163,9 @@ object SVDPlusPlus { vd: (Array[Double], Array[Double], Double, Double), msg: Option[(Array[Double], Array[Double], Double)]) => { val out1 = vd._1.clone() - blas.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1) + BLAS.nativeBLAS.daxpy(out1.length, 1.0, msg.get._1, 1, out1, 1) val out2 = vd._2.clone() - blas.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1) + BLAS.nativeBLAS.daxpy(out2.length, 1.0, msg.get._2, 1, out2, 1) (out1, out2, vd._3 + msg.get._3, vd._4) } }.cache() @@ -180,7 +179,7 @@ object SVDPlusPlus { (ctx: EdgeContext[(Array[Double], Array[Double], Double, Double), Double, Double]): Unit = { val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) - var pred = u + usr._3 + itm._3 + blas.ddot(q.length, q, 1, usr._2, 1) + var pred = u + usr._3 + itm._3 + BLAS.nativeBLAS.ddot(q.length, q, 1, usr._2, 1) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) val err = (ctx.attr - pred) * 
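The `nativeBLAS` accessor changed just below implements the fallback this comment refers to: try the native `dev.ludovic.netlib` implementation first and drop back to the pure-Java one if it cannot be loaded. A standalone sketch of the same resolution order, assuming only the `dev.ludovic.netlib` BLAS artifact on the classpath:

```scala
import dev.ludovic.netlib.{BLAS => NetlibBLAS, JavaBLAS, NativeBLAS}

object BlasFallback {
  def main(args: Array[String]): Unit = {
    // Prefer the native implementation; fall back to the Java one on any failure.
    val blas: NetlibBLAS =
      try NativeBLAS.getInstance
      catch { case _: Throwable => JavaBLAS.getInstance }
    println(s"using ${blas.getClass.getName}")

    // The call site is identical for either backend: y := 2.0 * x + y (daxpy mutates y in place).
    val x = Array(1.0, 2.0, 3.0)
    val y = Array(1.0, 1.0, 1.0)
    blas.daxpy(x.length, 2.0, x, 1, y, 1)
    println(y.mkString(", ")) // 3.0, 5.0, 7.0
  }
}
```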
(ctx.attr - pred) diff --git a/licenses-binary/LICENSE-blas.txt b/licenses-binary/LICENSE-blas.txt new file mode 100644 index 0000000000..2b8bec28b0 --- /dev/null +++ b/licenses-binary/LICENSE-blas.txt @@ -0,0 +1,25 @@ +MIT License +----------- + +Copyright 2020, 2021, Ludovic Henry + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Please contact git@ludovic.dev or visit ludovic.dev if you need additional +information or have any questions. diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 16cd55c8a4..a977ae3c10 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -75,6 +75,11 @@ test-jar test + + + dev.ludovic.netlib + blas + @@ -88,34 +93,6 @@ - - jvm-vectorized - - src/jvm-vectorized/java - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-vectorized-sources - generate-sources - - add-source - - - - ${extra.source.dir} - - - - - - - - target/scala-${scala.binary.version}/classes diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 518c71129a..5a6bee3e74 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -17,7 +17,9 @@ package org.apache.spark.ml.linalg -import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} +import dev.ludovic.netlib.{BLAS => NetlibBLAS, + JavaBLAS => NetlibJavaBLAS, + NativeBLAS => NetlibNativeBLAS} /** * BLAS routines for MLlib's vectors and matrices. @@ -29,38 +31,23 @@ private[spark] object BLAS extends Serializable { private val nativeL1Threshold: Int = 256 // For level-1 function dspmv, use javaBLAS for better performance. - private[ml] def javaBLAS: NetlibBLAS = { + private[spark] def javaBLAS: NetlibBLAS = { if (_javaBLAS == null) { - _javaBLAS = - try { - // scalastyle:off classforname - Class.forName("org.apache.spark.ml.linalg.VectorizedBLAS", true, - Option(Thread.currentThread().getContextClassLoader) - .getOrElse(getClass.getClassLoader)) - .newInstance() - .asInstanceOf[NetlibBLAS] - // scalastyle:on classforname - } catch { - case _: Throwable => new F2jBLAS - } + _javaBLAS = NetlibJavaBLAS.getInstance } _javaBLAS } // For level-3 routines, we use the native BLAS. 
- private[ml] def nativeBLAS: NetlibBLAS = { + private[spark] def nativeBLAS: NetlibBLAS = { if (_nativeBLAS == null) { _nativeBLAS = - if (NetlibBLAS.getInstance.isInstanceOf[F2jBLAS]) { - javaBLAS - } else { - NetlibBLAS.getInstance - } + try { NetlibNativeBLAS.getInstance } catch { case _: Throwable => javaBLAS } } _nativeBLAS } - private[ml] def getBLAS(vectorSize: Int): NetlibBLAS = { + private[spark] def getBLAS(vectorSize: Int): NetlibBLAS = { if (vectorSize < nativeL1Threshold) { javaBLAS } else { diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASBenchmark.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASBenchmark.scala index 1dcfcf9ebb..144f59ac17 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASBenchmark.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASBenchmark.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.linalg -import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} +import dev.ludovic.netlib.blas.NetlibF2jBLAS +import scala.concurrent.duration._ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} @@ -38,48 +39,66 @@ object BLASBenchmark extends BenchmarkBase { val iters = 1e2.toInt val rnd = new scala.util.Random(0) - val f2jBLAS = new F2jBLAS - val nativeBLAS = NetlibBLAS.getInstance - val vectorBLAS = - try { - // scalastyle:off classforname - Class.forName("org.apache.spark.ml.linalg.VectorizedBLAS", true, - Option(Thread.currentThread().getContextClassLoader) - .getOrElse(getClass.getClassLoader)) - .newInstance() - .asInstanceOf[NetlibBLAS] - // scalastyle:on classforname - } catch { - case _: Throwable => new F2jBLAS - } + val f2jBLAS = NetlibF2jBLAS.getInstance + val javaBLAS = BLAS.javaBLAS + val nativeBLAS = BLAS.nativeBLAS // scalastyle:off println - println("nativeBLAS = " + nativeBLAS.getClass.getName) println("f2jBLAS = " + f2jBLAS.getClass.getName) - println("vectorBLAS = " + vectorBLAS.getClass.getName) + println("javaBLAS = " + javaBLAS.getClass.getName) + println("nativeBLAS = " + nativeBLAS.getClass.getName) // scalastyle:on println runBenchmark("daxpy") { - val n = 1e7.toInt + val n = 1e8.toInt val alpha = rnd.nextDouble val x = Array.fill(n) { rnd.nextDouble } val y = Array.fill(n) { rnd.nextDouble } - val benchmark = new Benchmark("daxpy", n, iters, output = output) + val benchmark = new Benchmark("daxpy", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.daxpy(n, alpha, x, 1, y, 1) + f2jBLAS.daxpy(n, alpha, x, 1, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { + benchmark.addCase("java") { _ => + javaBLAS.daxpy(n, alpha, x, 1, y.clone, 1) + } + + if (nativeBLAS != javaBLAS) { benchmark.addCase("native") { _ => - nativeBLAS.daxpy(n, alpha, x, 1, y, 1) + nativeBLAS.daxpy(n, alpha, x, 1, y.clone, 1) } } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.daxpy(n, alpha, x, 1, y, 1) + benchmark.run() + } + + runBenchmark("saxpy") { + val n = 1e8.toInt + val alpha = rnd.nextFloat + val x = Array.fill(n) { rnd.nextFloat } + val y = Array.fill(n) { rnd.nextFloat } + + val benchmark = new Benchmark("saxpy", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.saxpy(n, alpha, x, 1, y.clone, 1) + } + + benchmark.addCase("java") { _ => + javaBLAS.saxpy(n, alpha, x, 1, y.clone, 1) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { 
_ => + nativeBLAS.saxpy(n, alpha, x, 1, y.clone, 1) } } @@ -87,25 +106,26 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("ddot") { - val n = 1e7.toInt + val n = 1e8.toInt val x = Array.fill(n) { rnd.nextDouble } val y = Array.fill(n) { rnd.nextDouble } - val benchmark = new Benchmark("ddot", n, iters, output = output) + val benchmark = new Benchmark("ddot", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => f2jBLAS.ddot(n, x, 1, y, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.ddot(n, x, 1, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.ddot(n, x, 1, y, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.ddot(n, x, 1, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.ddot(n, x, 1, y, 1) } } @@ -113,25 +133,26 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("sdot") { - val n = 1e7.toInt + val n = 1e8.toInt val x = Array.fill(n) { rnd.nextFloat } val y = Array.fill(n) { rnd.nextFloat } - val benchmark = new Benchmark("sdot", n, iters, output = output) + val benchmark = new Benchmark("sdot", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => f2jBLAS.sdot(n, x, 1, y, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.sdot(n, x, 1, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.sdot(n, x, 1, y, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.sdot(n, x, 1, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sdot(n, x, 1, y, 1) } } @@ -139,25 +160,26 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("dscal") { - val n = 1e7.toInt + val n = 1e8.toInt val alpha = rnd.nextDouble val x = Array.fill(n) { rnd.nextDouble } - val benchmark = new Benchmark("dscal", n, iters, output = output) + val benchmark = new Benchmark("dscal", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dscal(n, alpha, x, 1) + f2jBLAS.dscal(n, alpha, x.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dscal(n, alpha, x, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.dscal(n, alpha, x.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dscal(n, alpha, x, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dscal(n, alpha, x.clone, 1) } } @@ -165,25 +187,26 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("sscal") { - val n = 1e7.toInt + val n = 1e8.toInt val alpha = rnd.nextFloat val x = Array.fill(n) { rnd.nextFloat } - val benchmark = new Benchmark("sscal", n, iters, output = output) + val benchmark = new Benchmark("sscal", n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.sscal(n, alpha, x, 1) + f2jBLAS.sscal(n, alpha, x.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.sscal(n, alpha, x, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.sscal(n, alpha, x.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - 
benchmark.addCase("vector") { _ => - vectorBLAS.sscal(n, alpha, x, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sscal(n, alpha, x.clone, 1) } } @@ -191,28 +214,29 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("dspmv[U]") { - val n = 1e4.toInt + val n = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(n * (n + 1) / 2) { rnd.nextDouble } val x = Array.fill(n) { rnd.nextDouble } val beta = rnd.nextDouble val y = Array.fill(n) { rnd.nextDouble } - val benchmark = new Benchmark("dspmv[U]", n, iters, output = output) + val benchmark = new Benchmark("dspmv[U]", n * (n + 1) / 2, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dspmv("U", n, alpha, a, x, 1, beta, y, 1) + f2jBLAS.dspmv("U", n, alpha, a, x, 1, beta, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dspmv("U", n, alpha, a, x, 1, beta, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.dspmv("U", n, alpha, a, x, 1, beta, y.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dspmv("U", n, alpha, a, x, 1, beta, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dspmv("U", n, alpha, a, x, 1, beta, y.clone, 1) } } @@ -220,26 +244,27 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("dspr[U]") { - val n = 1e4.toInt + val n = 1e3.toInt val alpha = rnd.nextDouble val x = Array.fill(n) { rnd.nextDouble } val a = Array.fill(n * (n + 1) / 2) { rnd.nextDouble } - val benchmark = new Benchmark("dspr[U]", n, iters, output = output) + val benchmark = new Benchmark("dspr[U]", n * (n + 1) / 2, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dspr("U", n, alpha, x, 1, a) + f2jBLAS.dspr("U", n, alpha, x, 1, a.clone) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dspr("U", n, alpha, x, 1, a) - } + benchmark.addCase("java") { _ => + javaBLAS.dspr("U", n, alpha, x, 1, a.clone) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dspr("U", n, alpha, x, 1, a) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dspr("U", n, alpha, x, 1, a.clone) } } @@ -247,26 +272,27 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("dsyr[U]") { - val n = 1e4.toInt + val n = 1e3.toInt val alpha = rnd.nextDouble val x = Array.fill(n) { rnd.nextDouble } val a = Array.fill(n * n) { rnd.nextDouble } - val benchmark = new Benchmark("dsyr[U]", n, iters, output = output) + val benchmark = new Benchmark("dsyr[U]", n * (n + 1) / 2, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dsyr("U", n, alpha, x, 1, a, n) + f2jBLAS.dsyr("U", n, alpha, x, 1, a.clone, n) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dsyr("U", n, alpha, x, 1, a, n) - } + benchmark.addCase("java") { _ => + javaBLAS.dsyr("U", n, alpha, x, 1, a.clone, n) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dsyr("U", n, alpha, x, 1, a, n) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dsyr("U", n, alpha, x, 1, a.clone, n) } } @@ -274,7 +300,7 @@ object 
BLASBenchmark extends BenchmarkBase { } runBenchmark("dgemv[N]") { - val m = 1e4.toInt + val m = 1e3.toInt val n = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(m * n) { rnd.nextDouble } @@ -283,21 +309,22 @@ object BLASBenchmark extends BenchmarkBase { val beta = rnd.nextDouble val y = Array.fill(m) { rnd.nextDouble } - val benchmark = new Benchmark("dgemv[N]", n, iters, output = output) + val benchmark = new Benchmark("dgemv[N]", m * n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) + f2jBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } } @@ -305,7 +332,7 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("dgemv[T]") { - val m = 1e4.toInt + val m = 1e3.toInt val n = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(m * n) { rnd.nextDouble } @@ -314,21 +341,22 @@ object BLASBenchmark extends BenchmarkBase { val beta = rnd.nextDouble val y = Array.fill(n) { rnd.nextDouble } - val benchmark = new Benchmark("dgemv[T]", n, iters, output = output) + val benchmark = new Benchmark("dgemv[T]", m * n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) + f2jBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } } @@ -336,7 +364,7 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("sgemv[N]") { - val m = 1e4.toInt + val m = 1e3.toInt val n = 1e3.toInt val alpha = rnd.nextFloat val a = Array.fill(m * n) { rnd.nextFloat } @@ -345,21 +373,22 @@ object BLASBenchmark extends BenchmarkBase { val beta = rnd.nextFloat val y = Array.fill(m) { rnd.nextFloat } - val benchmark = new Benchmark("sgemv[N]", n, iters, output = output) + val benchmark = new Benchmark("sgemv[N]", m * n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) + f2jBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if 
(!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sgemv("N", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } } @@ -367,7 +396,7 @@ object BLASBenchmark extends BenchmarkBase { } runBenchmark("sgemv[T]") { - val m = 1e4.toInt + val m = 1e3.toInt val n = 1e3.toInt val alpha = rnd.nextFloat val a = Array.fill(m * n) { rnd.nextFloat } @@ -376,21 +405,22 @@ object BLASBenchmark extends BenchmarkBase { val beta = rnd.nextFloat val y = Array.fill(n) { rnd.nextFloat } - val benchmark = new Benchmark("sgemv[T]", n, iters, output = output) + val benchmark = new Benchmark("sgemv[T]", m * n, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) + f2jBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) - } + benchmark.addCase("java") { _ => + javaBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sgemv("T", m, n, alpha, a, lda, x, 1, beta, y.clone, 1) } } @@ -399,7 +429,7 @@ object BLASBenchmark extends BenchmarkBase { runBenchmark("dgemm[N,N]") { val m = 1e3.toInt - val n = 1e2.toInt + val n = 1e3.toInt val k = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(m * k) { rnd.nextDouble } @@ -410,21 +440,22 @@ object BLASBenchmark extends BenchmarkBase { val c = Array.fill(m * n) { rnd.nextDouble } var ldc = m - val benchmark = new Benchmark("dgemm[N,N]", m*n, iters, output = output) + val benchmark = new Benchmark("dgemm[N,N]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + f2jBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) - } + benchmark.addCase("java") { _ => + javaBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } } @@ -433,7 +464,7 @@ object BLASBenchmark extends BenchmarkBase { runBenchmark("dgemm[N,T]") { val m = 1e3.toInt - val n = 1e2.toInt + val n = 1e3.toInt val k = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(m * k) { rnd.nextDouble } @@ -444,21 +475,22 @@ object BLASBenchmark extends BenchmarkBase { val c = Array.fill(m * n) { rnd.nextDouble } var ldc = m - val benchmark = new Benchmark("dgemm[N,T]", m*n, iters, output = output) + val benchmark = new Benchmark("dgemm[N,T]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - 
f2jBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + f2jBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("native") { _ => - nativeBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) - } + benchmark.addCase("java") { _ => + javaBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } } @@ -467,7 +499,7 @@ object BLASBenchmark extends BenchmarkBase { runBenchmark("dgemm[T,N]") { val m = 1e3.toInt - val n = 1e2.toInt + val n = 1e3.toInt val k = 1e3.toInt val alpha = rnd.nextDouble val a = Array.fill(m * k) { rnd.nextDouble } @@ -478,21 +510,197 @@ object BLASBenchmark extends BenchmarkBase { val c = Array.fill(m * n) { rnd.nextDouble } var ldc = m - val benchmark = new Benchmark("dgemm[T,N]", m*n, iters, output = output) + val benchmark = new Benchmark("dgemm[T,N]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + } + + benchmark.run() + } + + runBenchmark("dgemm[T,T]") { + val m = 1e3.toInt + val n = 1e3.toInt + val k = 1e3.toInt + val alpha = rnd.nextDouble + val a = Array.fill(m * k) { rnd.nextDouble } + val lda = k + val b = Array.fill(k * n) { rnd.nextDouble } + val ldb = n + val beta = rnd.nextDouble + val c = Array.fill(m * n) { rnd.nextDouble } + var ldc = m + + val benchmark = new Benchmark("dgemm[T,T]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.dgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.dgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.dgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + } + + benchmark.run() + } + + runBenchmark("sgemm[N,N]") { + val m = 1e3.toInt + val n = 1e3.toInt + val k = 1e3.toInt + val alpha = rnd.nextFloat + val a = Array.fill(m * k) { rnd.nextFloat } + val lda = m + val b = Array.fill(k * n) { rnd.nextFloat } + val ldb = k + val beta = rnd.nextFloat + val c = Array.fill(m * n) { rnd.nextFloat } + var ldc = m + + val benchmark = new Benchmark("sgemm[N,N]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.sgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.sgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sgemm("N", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + } + + benchmark.run() + } + + 
runBenchmark("sgemm[N,T]") { + val m = 1e3.toInt + val n = 1e3.toInt + val k = 1e3.toInt + val alpha = rnd.nextFloat + val a = Array.fill(m * k) { rnd.nextFloat } + val lda = m + val b = Array.fill(k * n) { rnd.nextFloat } + val ldb = n + val beta = rnd.nextFloat + val c = Array.fill(m * n) { rnd.nextFloat } + var ldc = m + + val benchmark = new Benchmark("sgemm[N,T]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.sgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.sgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sgemm("N", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + } + + benchmark.run() + } + + runBenchmark("sgemm[T,N]") { + val m = 1e3.toInt + val n = 1e3.toInt + val k = 1e3.toInt + val alpha = rnd.nextFloat + val a = Array.fill(m * k) { rnd.nextFloat } + val lda = k + val b = Array.fill(k * n) { rnd.nextFloat } + val ldb = k + val beta = rnd.nextFloat + val c = Array.fill(m * n) { rnd.nextFloat } + var ldc = m + + val benchmark = new Benchmark("sgemm[T,N]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) benchmark.addCase("f2j") { _ => - f2jBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + f2jBLAS.sgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.sgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } - if (!nativeBLAS.getClass.equals(classOf[F2jBLAS])) { + if (nativeBLAS != javaBLAS) { benchmark.addCase("native") { _ => - nativeBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + nativeBLAS.sgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } } - if (!vectorBLAS.getClass.equals(classOf[F2jBLAS])) { - benchmark.addCase("vector") { _ => - vectorBLAS.dgemm("T", "N", m, n, k, alpha, a, lda, b, ldb, beta, c, ldc) + benchmark.run() + } + + runBenchmark("sgemm[T,T]") { + val m = 1e3.toInt + val n = 1e3.toInt + val k = 1e3.toInt + val alpha = rnd.nextFloat + val a = Array.fill(m * k) { rnd.nextFloat } + val lda = k + val b = Array.fill(k * n) { rnd.nextFloat } + val ldb = n + val beta = rnd.nextFloat + val c = Array.fill(m * n) { rnd.nextFloat } + var ldc = m + + val benchmark = new Benchmark("sgemm[T,T]", m * n * k, iters, + warmupTime = 30.seconds, + minTime = 30.seconds, + output = output) + + benchmark.addCase("f2j") { _ => + f2jBLAS.sgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + benchmark.addCase("java") { _ => + javaBLAS.sgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) + } + + if (nativeBLAS != javaBLAS) { + benchmark.addCase("native") { _ => + nativeBLAS.sgemm("T", "T", m, n, k, alpha, a, lda, b, ldb, beta, c.clone, ldc) } } diff --git a/mllib/pom.xml b/mllib/pom.xml index f5b5a979e3..626ac85ce1 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -142,6 +142,19 @@ test + + dev.ludovic.netlib + blas + + + dev.ludovic.netlib + lapack + + + dev.ludovic.netlib + arpack + + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 9191b3ec4b..9214f55130 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -222,6 +222,7 @@ class LinearSVC @Since("2.2.0") ( } val featuresStd = summarizer.std.toArray + val featuresMean = summarizer.mean.toArray val getFeaturesStd = (j: Int) => featuresStd(j) val regularization = if ($(regParam) != 0.0) { val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures @@ -239,7 +240,8 @@ class LinearSVC @Since("2.2.0") ( as a result, no scaling is needed. */ val (rawCoefficients, objectiveHistory) = - trainImpl(instances, actualBlockSizeInMB, featuresStd, regularization, optimizer) + trainImpl(instances, actualBlockSizeInMB, featuresStd, featuresMean, + regularization, optimizer) if (rawCoefficients == null) { val msg = s"${optimizer.getClass.getName} failed." @@ -277,16 +279,19 @@ class LinearSVC @Since("2.2.0") ( instances: RDD[Instance], actualBlockSizeInMB: Double, featuresStd: Array[Double], + featuresMean: Array[Double], regularization: Option[L2Regularization], optimizer: BreezeOWLQN[Int, BDV[Double]]): (Array[Double], Array[Double]) = { val numFeatures = featuresStd.length val numFeaturesPlusIntercept = if ($(fitIntercept)) numFeatures + 1 else numFeatures - val bcFeaturesStd = instances.context.broadcast(featuresStd) + val inverseStd = featuresStd.map(std => if (std != 0) 1.0 / std else 0.0) + val scaledMean = Array.tabulate(numFeatures)(i => inverseStd(i) * featuresMean(i)) + val bcInverseStd = instances.context.broadcast(inverseStd) + val bcScaledMean = instances.context.broadcast(scaledMean) val standardized = instances.mapPartitions { iter => - val inverseStd = bcFeaturesStd.value.map { std => if (std != 0) 1.0 / std else 0.0 } - val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true) + val func = StandardScalerModel.getTransformFunc(Array.empty, bcInverseStd.value, false, true) iter.map { case Instance(label, weight, vec) => Instance(label, weight, func(vec)) } } @@ -295,13 +300,24 @@ class LinearSVC @Since("2.2.0") ( .persist(StorageLevel.MEMORY_AND_DISK) .setName(s"training blocks (blockSizeInMB=$actualBlockSizeInMB)") - val getAggregatorFunc = new BlockHingeAggregator($(fitIntercept))(_) + val getAggregatorFunc = new HingeBlockAggregator(bcInverseStd, bcScaledMean, + $(fitIntercept))(_) val costFun = new RDDLossFunction(blocks, getAggregatorFunc, regularization, $(aggregationDepth)) - val states = optimizer.iterations(new CachedDiffFunction(costFun), - Vectors.zeros(numFeaturesPlusIntercept).asBreeze.toDenseVector) + val initialSolution = Array.ofDim[Double](numFeaturesPlusIntercept) + if ($(fitIntercept)) { + // orginal `initialSolution` is for problem: + // y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept) + // we should adjust it to the initial solution for problem: + // y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept) + // NOTE: this is NOOP before we finally support model initialization + val adapt = BLAS.javaBLAS.ddot(numFeatures, initialSolution, 1, scaledMean, 1) + initialSolution(numFeatures) += adapt + } + val states = optimizer.iterations(new CachedDiffFunction(costFun), + new BDV[Double](initialSolution)) val arrayBuilder = mutable.ArrayBuilder.make[Double] var state: optimizer.State = null while (states.hasNext) { @@ -309,9 +325,19 @@ class LinearSVC @Since("2.2.0") ( arrayBuilder += state.adjustedValue } blocks.unpersist() - bcFeaturesStd.destroy() - - (if (state != null) state.x.toArray else null, arrayBuilder.result) + bcInverseStd.destroy() + bcScaledMean.destroy() + + val solution = if 
(state == null) null else state.x.toArray + if ($(fitIntercept) && solution != null) { + // the final solution is for problem: + // y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept) + // we should adjust it back for original problem: + // y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept) + val adapt = BLAS.javaBLAS.ddot(numFeatures, solution, 1, scaledMean, 1) + solution(numFeatures) -= adapt + } + (solution, arrayBuilder.result) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 57fb46b451..c3c54651ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -982,14 +982,14 @@ class LogisticRegression @Since("1.2.0") ( val adapt = Array.ofDim[Double](numClasses) BLAS.javaBLAS.dgemv("N", numClasses, numFeatures, 1.0, initialSolution, numClasses, scaledMean, 1, 0.0, adapt, 1) - BLAS.getBLAS(numFeatures).daxpy(numClasses, 1.0, adapt, 0, 1, + BLAS.javaBLAS.daxpy(numClasses, 1.0, adapt, 0, 1, initialSolution, numClasses * numFeatures, 1) } else { - // orginal `initialCoefWithInterceptArray` is for problem: + // original `initialSolution` is for problem: // y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept) // we should adjust it to the initial solution for problem: // y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept) - val adapt = BLAS.getBLAS(numFeatures).ddot(numFeatures, initialSolution, 1, scaledMean, 1) + val adapt = BLAS.javaBLAS.ddot(numFeatures, initialSolution, 1, scaledMean, 1) initialSolution(numFeatures) += adapt } } @@ -1018,14 +1018,14 @@ class LogisticRegression @Since("1.2.0") ( val adapt = Array.ofDim[Double](numClasses) BLAS.javaBLAS.dgemv("N", numClasses, numFeatures, 1.0, solution, numClasses, scaledMean, 1, 0.0, adapt, 1) - BLAS.getBLAS(numFeatures).daxpy(numClasses, -1.0, adapt, 0, 1, + BLAS.javaBLAS.daxpy(numClasses, -1.0, adapt, 0, 1, solution, numClasses * numFeatures, 1) } else { // the final solution is for problem: // y = f(w1 * (x1 - avg_x1) / std_x1, w2 * (x2 - avg_x2) / std_x2, ..., intercept) // we should adjust it back for original problem: // y = f(w1 * x1 / std_x1, w2 * x2 / std_x2, ..., intercept) - val adapt = BLAS.getBLAS(numFeatures).ddot(numFeatures, solution, 1, scaledMean, 1) + val adapt = BLAS.javaBLAS.ddot(numFeatures, solution, 1, scaledMean, 1) solution(numFeatures) -= adapt } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregator.scala index 091c885ca0..09a4335dad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/BinaryLogisticBlockAggregator.scala @@ -72,7 +72,7 @@ private[ml] class BinaryLogisticBlockAggregator( // deal with non-zero values in prediction. 
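The comment above compresses the key algebra behind `fitWithMean`. A tiny self-contained check of the identity it relies on, with made-up numbers (illustrative only, plain Scala, not Spark API):

```scala
// With x_s = x / std and scaledMean = mean / std,
//   coef . (x_s - scaledMean) + intercept == coef . x_s + (intercept - coef . scaledMean)
// so the bracketed term (the marginOffset) is computed once, and a sparse x_s only
// contributes through its non-zero entries.
def dot(a: Array[Double], b: Array[Double]): Double =
  a.zip(b).map { case (u, v) => u * v }.sum

val coef = Array(0.5, -1.0, 2.0)
val intercept = 0.3
val xScaled = Array(1.0, 0.0, 4.0)      // features already divided by std
val scaledMean = Array(0.2, 0.1, 0.5)   // mean / std

val centeredMargin =
  dot(coef, xScaled.zip(scaledMean).map { case (u, v) => u - v }) + intercept
val marginOffset = intercept - dot(coef, scaledMean)
assert(math.abs(centeredMargin - (dot(coef, xScaled) + marginOffset)) < 1e-12)
```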
private val marginOffset = if (fitWithMean) { coefficientsArray.last - - BLAS.getBLAS(numFeatures).ddot(numFeatures, coefficientsArray, 1, bcScaledMean.value, 1) + BLAS.javaBLAS.ddot(numFeatures, coefficientsArray, 1, bcScaledMean.value, 1) } else { Double.NaN } @@ -142,7 +142,7 @@ private[ml] class BinaryLogisticBlockAggregator( case sm: SparseMatrix if fitIntercept => val linearGradSumVec = new DenseVector(Array.ofDim[Double](numFeatures)) BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec) - BLAS.getBLAS(numFeatures).daxpy(numFeatures, 1.0, linearGradSumVec.values, 1, + BLAS.javaBLAS.daxpy(numFeatures, 1.0, linearGradSumVec.values, 1, gradientSumArray, 1) case sm: SparseMatrix if !fitIntercept => @@ -156,7 +156,7 @@ private[ml] class BinaryLogisticBlockAggregator( if (fitWithMean) { // above update of the linear part of gradientSumArray does NOT take the centering // into account, here we need to adjust this part. - BLAS.getBLAS(numFeatures).daxpy(numFeatures, -multiplierSum, bcScaledMean.value, 1, + BLAS.javaBLAS.daxpy(numFeatures, -multiplierSum, bcScaledMean.value, 1, gradientSumArray, 1) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregator.scala new file mode 100644 index 0000000000..f99c531c96 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregator.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging +import org.apache.spark.ml.feature.InstanceBlock +import org.apache.spark.ml.linalg._ + + +/** + * HingeBlockAggregator computes the gradient and loss for Huber loss function + * as used in linear regression for blocks in sparse or dense matrix in an online fashion. + * + * Two BlockHuberAggregators can be merged together to have a summary of loss and gradient + * of the corresponding joint dataset. + * + * NOTE: The feature values are expected to already have be scaled (multiplied by bcInverseStd, + * but NOT centered) before computation. + * + * @param bcCoefficients The coefficients corresponding to the features. + * @param fitIntercept Whether to fit an intercept term. When true, will perform data centering + * in a virtual way. Then we MUST adjust the intercept of both initial + * coefficients and final solution in the caller. 
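Despite the Huber/linear-regression wording inherited by the class comment, the rule this new aggregator applies in `add` (further down in this file) is the hinge loss used by `LinearSVC` on {0, 1} labels. A standalone sketch of that per-instance rule, for orientation only:

```scala
// Illustrative only, not Spark API: hinge loss and the "multiplier" (gradient of the
// loss w.r.t. the margin) for a single weighted instance with a {0, 1} label.
def hingeLossAndMultiplier(label: Double, margin: Double, weight: Double): (Double, Double) = {
  val labelScaled = 2.0 * label - 1.0                          // map {0, 1} to {-1, +1}
  val loss = math.max(0.0, 1.0 - labelScaled * margin) * weight
  val multiplier = if (loss > 0) -labelScaled * weight else 0.0
  (loss, multiplier)
}

hingeLossAndMultiplier(label = 1.0, margin = 0.4, weight = 1.0)   // (0.6, -1.0): inside the margin
hingeLossAndMultiplier(label = 0.0, margin = -2.0, weight = 1.0)  // (0.0, 0.0): correctly classified
```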
+ */ +private[ml] class HingeBlockAggregator( + bcInverseStd: Broadcast[Array[Double]], + bcScaledMean: Broadcast[Array[Double]], + fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector]) + extends DifferentiableLossAggregator[InstanceBlock, HingeBlockAggregator] + with Logging { + + if (fitIntercept) { + require(bcScaledMean != null && bcScaledMean.value.length == bcInverseStd.value.length, + "scaled means is required when center the vectors") + } + + private val numFeatures = bcInverseStd.value.length + protected override val dim: Int = bcCoefficients.value.size + + @transient private lazy val coefficientsArray = bcCoefficients.value match { + case DenseVector(values) => values + case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector but " + + s"got type ${bcCoefficients.value.getClass}.)") + } + + @transient private lazy val linear = if (fitIntercept) { + new DenseVector(coefficientsArray.take(numFeatures)) + } else { + new DenseVector(coefficientsArray) + } + + // pre-computed margin of an empty vector. + // with this variable as an offset, for a sparse vector, we only need to + // deal with non-zero values in prediction. + private val marginOffset = if (fitIntercept) { + coefficientsArray.last - + BLAS.javaBLAS.ddot(numFeatures, coefficientsArray, 1, bcScaledMean.value, 1) + } else { + Double.NaN + } + + /** + * Add a new training instance block to this HingeBlockAggregator, and update the loss + * and gradient of the objective function. + * + * @param block The instance block of data point to be added. + * @return This HingeBlockAggregator object. + */ + def add(block: InstanceBlock): this.type = { + require(block.matrix.isTransposed) + require(numFeatures == block.numFeatures, s"Dimensions mismatch when adding new " + + s"instance. 
Expecting $numFeatures but got ${block.numFeatures}.") + require(block.weightIter.forall(_ >= 0), + s"instance weights ${block.weightIter.mkString("[", ",", "]")} has to be >= 0.0") + + if (block.weightIter.forall(_ == 0)) return this + val size = block.size + + // vec/arr here represents margins + val vec = new DenseVector(Array.ofDim[Double](size)) + val arr = vec.values + if (fitIntercept) java.util.Arrays.fill(arr, marginOffset) + BLAS.gemv(1.0, block.matrix, linear, 1.0, vec) + + // in-place convert margins to multiplier + // then, vec/arr represents multiplier + var localLossSum = 0.0 + var localWeightSum = 0.0 + var multiplierSum = 0.0 + var i = 0 + while (i < size) { + val weight = block.getWeight(i) + localWeightSum += weight + if (weight > 0) { + // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) + // Therefore the gradient is -(2y - 1)*x + val label = block.getLabel(i) + val labelScaled = label + label - 1.0 + val loss = (1.0 - labelScaled * arr(i)) * weight + if (loss > 0) { + localLossSum += loss + val multiplier = -labelScaled * weight + arr(i) = multiplier + multiplierSum += multiplier + } else { arr(i) = 0.0 } + } else { arr(i) = 0.0 } + i += 1 + } + lossSum += localLossSum + weightSum += localWeightSum + + // predictions are all correct, no gradient signal + if (arr.forall(_ == 0)) return this + + // update the linear part of gradientSumArray + block.matrix match { + case dm: DenseMatrix => + BLAS.nativeBLAS.dgemv("N", dm.numCols, dm.numRows, 1.0, dm.values, dm.numCols, + vec.values, 1, 1.0, gradientSumArray, 1) + + case sm: SparseMatrix if fitIntercept => + val linearGradSumVec = new DenseVector(Array.ofDim[Double](numFeatures)) + BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec) + BLAS.javaBLAS.daxpy(numFeatures, 1.0, linearGradSumVec.values, 1, + gradientSumArray, 1) + + case sm: SparseMatrix if !fitIntercept => + val gradSumVec = new DenseVector(gradientSumArray) + BLAS.gemv(1.0, sm.transpose, vec, 1.0, gradSumVec) + + case m => + throw new IllegalArgumentException(s"Unknown matrix type ${m.getClass}.") + } + + if (fitIntercept) { + // above update of the linear part of gradientSumArray does NOT take the centering + // into account, here we need to adjust this part. 
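The adjustment described just above follows from splitting the centered sum: each row contributes multiplier_i * (x_i_scaled - scaledMean), and the sum of those terms equals (the transposed feature block times the multipliers) minus multiplierSum * scaledMean, so the gemv covers the first part and a single daxpy the second. A small numeric check with made-up values (illustrative only):

```scala
val rows = Array(Array(1.0, 2.0), Array(3.0, 4.0))   // scaled feature rows
val m = Array(0.5, -1.0)                             // per-row multipliers
val mu = Array(0.25, 0.75)                           // scaledMean

val direct = Array.tabulate(2) { j =>
  rows.indices.map(i => m(i) * (rows(i)(j) - mu(j))).sum
}
val split = {
  val xtM = Array.tabulate(2)(j => rows.indices.map(i => m(i) * rows(i)(j)).sum)
  Array.tabulate(2)(j => xtM(j) - m.sum * mu(j))
}
assert(direct.zip(split).forall { case (a, b) => math.abs(a - b) < 1e-12 })
```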
+ BLAS.javaBLAS.daxpy(numFeatures, -multiplierSum, bcScaledMean.value, 1, + gradientSumArray, 1) + + // update the intercept part of gradientSumArray + gradientSumArray(numFeatures) += multiplierSum + } + + this + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregator.scala index de64440843..0683cec628 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregator.scala @@ -203,7 +203,7 @@ private[ml] class MultinomialLogisticBlockAggregator( } if (fitIntercept) { - BLAS.getBLAS(numClasses).daxpy(numClasses, 1.0, multiplierSum, 0, 1, + BLAS.javaBLAS.daxpy(numClasses, 1.0, multiplierSum, 0, 1, gradientSumArray, numClasses * numFeatures, 1) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index df64de4b10..837883e53d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration private[libsvm] class LibSVMOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a2c376d80e..d2cfedcc33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -22,7 +22,6 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import com.github.fommil.netlib.BLAS.{getInstance => blas} import com.google.common.collect.{Ordering => GuavaOrdering} import org.json4s.DefaultFormats import org.json4s.JsonDSL._ @@ -34,6 +33,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Kryo.KRYO_SERIALIZER_MAX_BUFFER_SIZE +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ @@ -401,18 +401,18 @@ class Word2Vec extends Serializable with Logging { val inner = bcVocab.value(word).point(d) val l2 = inner * vectorSize // Propagate hidden -> output - var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) + var f = BLAS.nativeBLAS.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt f = expTable.value(ind) val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat - blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) - blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) + BLAS.nativeBLAS.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) + BLAS.nativeBLAS.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) syn1Modify(inner) += 1 } d += 1 } - blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) + BLAS.nativeBLAS.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) syn0Modify(lastWord) += 1 } } @@ 
-448,10 +448,10 @@ class Word2Vec extends Serializable with Logging { (id, (vec, 1)) } }.reduceByKey { (vc1, vc2) => - blas.saxpy(vectorSize, 1.0f, vc2._1, 1, vc1._1, 1) + BLAS.nativeBLAS.saxpy(vectorSize, 1.0f, vc2._1, 1, vc1._1, 1) (vc1._1, vc1._2 + vc2._2) }.map { case (id, (vec, count)) => - blas.sscal(vectorSize, 1.0f / count, vec, 1) + BLAS.nativeBLAS.sscal(vectorSize, 1.0f / count, vec, 1) (id, vec) }.collect() var i = 0 @@ -511,7 +511,7 @@ class Word2VecModel private[spark] ( private lazy val wordVecInvNorms: Array[Float] = { val size = vectorSize Array.tabulate(numWords) { i => - val norm = blas.snrm2(size, wordVectors, i * size, 1) + val norm = BLAS.nativeBLAS.snrm2(size, wordVectors, i * size, 1) if (norm != 0) 1 / norm else 0.0F } } @@ -587,7 +587,7 @@ class Word2VecModel private[spark] ( val localVectorSize = vectorSize val floatVec = vector.map(_.toFloat) - val vecNorm = blas.snrm2(localVectorSize, floatVec, 1) + val vecNorm = BLAS.nativeBLAS.snrm2(localVectorSize, floatVec, 1) val localWordList = wordList val localNumWords = numWords @@ -597,11 +597,11 @@ class Word2VecModel private[spark] ( .take(num) .toArray } else { - // Normalize input vector before blas.sgemv to avoid Inf value - blas.sscal(localVectorSize, 1 / vecNorm, floatVec, 0, 1) + // Normalize input vector before BLAS.nativeBLAS.sgemv to avoid Inf value + BLAS.nativeBLAS.sscal(localVectorSize, 1 / vecNorm, floatVec, 0, 1) val cosineVec = Array.ofDim[Float](localNumWords) - blas.sgemv("T", localVectorSize, localNumWords, 1.0F, wordVectors, localVectorSize, + BLAS.nativeBLAS.sgemv("T", localVectorSize, localNumWords, 1.0F, wordVectors, localVectorSize, floatVec, 1, 0.0F, cosineVec, 1) val localWordVecInvNorms = wordVecInvNorms diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala new file mode 100644 index 0000000000..fb0f6ddd47 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/ARPACK.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import dev.ludovic.netlib.{ARPACK => NetlibARPACK, + JavaARPACK => NetlibJavaARPACK, + NativeARPACK => NetlibNativeARPACK} + +/** + * ARPACK routines for MLlib's vectors and matrices. 
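The new `ARPACK` wrapper whose body follows, like the `BLAS` and `LAPACK` wrappers added elsewhere in this patch, relies on one acquisition idiom: try the native binding, fall back to the pure-Java one on any failure. Written out generically (illustrative sketch, not Spark API):

```scala
// Prefer the native binding; if it cannot be loaded (missing system library,
// unsupported platform, ...), silently fall back to the pure-Java implementation.
def pickBackend[T](native: => T, java: => T): T =
  try native catch { case _: Throwable => java }

// e.g., mirroring the object below, with the dev.ludovic.netlib aliases imported in this file:
// val arpack = pickBackend(NetlibNativeARPACK.getInstance, NetlibJavaARPACK.getInstance)
```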
+ */ +private[spark] object ARPACK extends Serializable { + + @transient private var _javaARPACK: NetlibARPACK = _ + @transient private var _nativeARPACK: NetlibARPACK = _ + + private[spark] def javaARPACK: NetlibARPACK = { + if (_javaARPACK == null) { + _javaARPACK = NetlibJavaARPACK.getInstance + } + _javaARPACK + } + + private[spark] def nativeARPACK: NetlibARPACK = { + if (_nativeARPACK == null) { + _nativeARPACK = + try { NetlibNativeARPACK.getInstance } catch { case _: Throwable => javaARPACK } + } + _nativeARPACK + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index bd60364326..e38cfe4e18 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -17,8 +17,9 @@ package org.apache.spark.mllib.linalg -import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} -import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} +import dev.ludovic.netlib.{BLAS => NetlibBLAS, + JavaBLAS => NetlibJavaBLAS, + NativeBLAS => NetlibNativeBLAS} import org.apache.spark.internal.Logging @@ -27,21 +28,30 @@ import org.apache.spark.internal.Logging */ private[spark] object BLAS extends Serializable with Logging { - @transient private var _f2jBLAS: NetlibBLAS = _ + @transient private var _javaBLAS: NetlibBLAS = _ @transient private var _nativeBLAS: NetlibBLAS = _ private val nativeL1Threshold: Int = 256 - // For level-1 function dspmv, use f2jBLAS for better performance. - private[mllib] def f2jBLAS: NetlibBLAS = { - if (_f2jBLAS == null) { - _f2jBLAS = new F2jBLAS + // For level-1 function dspmv, use javaBLAS for better performance. + private[spark] def javaBLAS: NetlibBLAS = { + if (_javaBLAS == null) { + _javaBLAS = NetlibJavaBLAS.getInstance } - _f2jBLAS + _javaBLAS } - private[mllib] def getBLAS(vectorSize: Int): NetlibBLAS = { + // For level-3 routines, we use the native BLAS. + private[spark] def nativeBLAS: NetlibBLAS = { + if (_nativeBLAS == null) { + _nativeBLAS = + try { NetlibNativeBLAS.getInstance } catch { case _: Throwable => javaBLAS } + } + _nativeBLAS + } + + private[spark] def getBLAS(vectorSize: Int): NetlibBLAS = { if (vectorSize < nativeL1Threshold) { - f2jBLAS + javaBLAS } else { nativeBLAS } @@ -237,14 +247,6 @@ private[spark] object BLAS extends Serializable with Logging { } } - // For level-3 routines, we use the native BLAS. - private[mllib] def nativeBLAS: NetlibBLAS = { - if (_nativeBLAS == null) { - _nativeBLAS = NativeBLAS - } - _nativeBLAS - } - /** * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR. 
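`?SPR` operates on a symmetric matrix in packed upper-triangular storage, the same layout the BLASBenchmark above sizes as `n * (n + 1) / 2`. For reference, the indexing convention (illustrative, plain Scala):

```scala
// Column-major packed "U" storage: for 0-based i <= j, element (i, j) of an n x n
// symmetric matrix lives at index i + j * (j + 1) / 2 in an array of length n * (n + 1) / 2.
val n = 3
val packed = new Array[Double](n * (n + 1) / 2)
def packedIndex(i: Int, j: Int): Int = { require(i <= j); i + j * (j + 1) / 2 }
packed(packedIndex(0, 2)) = 7.0   // sets A(0, 2), and by symmetry A(2, 0)
```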
* @@ -263,7 +265,7 @@ private[spark] object BLAS extends Serializable with Logging { val n = v.size v match { case DenseVector(values) => - NativeBLAS.dspr("U", n, alpha, values, 1, U) + nativeBLAS.dspr("U", n, alpha, values, 1, U) case SparseVector(size, indices, values) => val nnz = indices.length var colStartIdx = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala index 68771f1afb..f06ea9418f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/CholeskyDecomposition.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.linalg -import com.github.fommil.netlib.LAPACK.{getInstance => lapack} import org.netlib.util.intW import org.apache.spark.ml.optim.SingularMatrixException @@ -37,7 +36,7 @@ private[spark] object CholeskyDecomposition { def solve(A: Array[Double], bx: Array[Double]): Array[Double] = { val k = bx.length val info = new intW(0) - lapack.dppsv("U", k, 1, A, bx, k, info) + LAPACK.nativeLAPACK.dppsv("U", k, 1, A, bx, k, info) checkReturnValue(info, "dppsv") bx } @@ -52,7 +51,7 @@ private[spark] object CholeskyDecomposition { */ def inverse(UAi: Array[Double], k: Int): Array[Double] = { val info = new intW(0) - lapack.dpptri("U", k, UAi, info) + LAPACK.nativeLAPACK.dpptri("U", k, UAi, info) checkReturnValue(info, "dpptri") UAi } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 4c71cd6496..2cbf5d09dc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.linalg import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} -import com.github.fommil.netlib.ARPACK import org.netlib.util.{doubleW, intW} /** @@ -51,8 +50,6 @@ private[mllib] object EigenValueDecomposition { // TODO: remove this function and use eigs in breeze when switching breeze version require(n > k, s"Number of required eigenvalues $k must be smaller than matrix dimension $n") - val arpack = ARPACK.getInstance() - // tolerance used in stopping criterion val tolW = new doubleW(tol) // number of desired eigenvalues, 0 < nev < n @@ -87,8 +84,8 @@ private[mllib] object EigenValueDecomposition { val ipntr = new Array[Int](11) // call ARPACK's reverse communication, first iteration with ido = 0 - arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr, workd, - workl, workl.length, info) + ARPACK.nativeARPACK.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, + v, n, iparam, ipntr, workd, workl, workl.length, info) val w = BDV(workd) @@ -105,8 +102,8 @@ private[mllib] object EigenValueDecomposition { val y = w.slice(outputOffset, outputOffset + n) y := mul(x) // call ARPACK's reverse communication - arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr, - workd, workl, workl.length, info) + ARPACK.nativeARPACK.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, + v, n, iparam, ipntr, workd, workl, workl.length, info) } if (info.`val` != 0) { @@ -127,8 +124,8 @@ private[mllib] object EigenValueDecomposition { val z = java.util.Arrays.copyOfRange(v, 0, nev.`val` * n) // call ARPACK's post-processing for eigenvectors - 
arpack.dseupd(true, "A", select, d, z, n, 0.0, bmat, n, which, nev, tol, resid, ncv, v, n, - iparam, ipntr, workd, workl, workl.length, info) + ARPACK.nativeARPACK.dseupd(true, "A", select, d, z, n, 0.0, bmat, n, which, nev, tol, resid, + ncv, v, n, iparam, ipntr, workd, workl, workl.length, info) // number of computed eigenvalues, might be smaller than k val computed = iparam(4) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala new file mode 100644 index 0000000000..4d25aed283 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/LAPACK.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import dev.ludovic.netlib.{JavaLAPACK => NetlibJavaLAPACK, + LAPACK => NetlibLAPACK, + NativeLAPACK => NetlibNativeLAPACK} + +/** + * LAPACK routines for MLlib's vectors and matrices. + */ +private[spark] object LAPACK extends Serializable { + + @transient private var _javaLAPACK: NetlibLAPACK = _ + @transient private var _nativeLAPACK: NetlibLAPACK = _ + + private[spark] def javaLAPACK: NetlibLAPACK = { + if (_javaLAPACK == null) { + _javaLAPACK = NetlibJavaLAPACK.getInstance + } + _javaLAPACK + } + + private[spark] def nativeLAPACK: NetlibLAPACK = { + if (_nativeLAPACK == null) { + _nativeLAPACK = + try { NetlibNativeLAPACK.getInstance } catch { case _: Throwable => javaLAPACK } + } + _nativeLAPACK + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 57edc96511..e4f64b4e34 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, Has import scala.language.implicitConversions import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} -import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.annotation.Since import org.apache.spark.ml.{linalg => newlinalg} @@ -427,7 +426,7 @@ class DenseMatrix @Since("1.3.0") ( if (isTransposed) { Iterator.tabulate(numCols) { j => val col = new Array[Double](numRows) - blas.dcopy(numRows, values, j, numCols, col, 0, 1) + BLAS.nativeBLAS.dcopy(numRows, values, j, numCols, col, 0, 1) new DenseVector(col) } } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala index 86632ae335..e070d605b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.optimization import java.{util => ju} -import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.ml.linalg.BLAS /** * Object used to solve nonnegative least squares problems using a modified @@ -75,10 +75,10 @@ private[spark] object NNLS { // find the optimal unconstrained step def steplen(dir: Array[Double], res: Array[Double]): Double = { - val top = blas.ddot(n, dir, 1, res, 1) - blas.dgemv("N", n, n, 1.0, ata, n, dir, 1, 0.0, scratch, 1) + val top = BLAS.nativeBLAS.ddot(n, dir, 1, res, 1) + BLAS.nativeBLAS.dgemv("N", n, n, 1.0, ata, n, dir, 1, 0.0, scratch, 1) // Push the denominator upward very slightly to avoid infinities and silliness - top / (blas.ddot(n, scratch, 1, dir, 1) + 1e-20) + top / (BLAS.nativeBLAS.ddot(n, scratch, 1, dir, 1) + 1e-20) } // stopping condition @@ -103,9 +103,9 @@ private[spark] object NNLS { var i = 0 while (iterno < iterMax) { // find the residual - blas.dgemv("N", n, n, 1.0, ata, n, x, 1, 0.0, res, 1) - blas.daxpy(n, -1.0, atb, 1, res, 1) - blas.dcopy(n, res, 1, grad, 1) + BLAS.nativeBLAS.dgemv("N", n, n, 1.0, ata, n, x, 1, 0.0, res, 1) + BLAS.nativeBLAS.daxpy(n, -1.0, atb, 1, res, 1) + BLAS.nativeBLAS.dcopy(n, res, 1, grad, 1) // project the gradient i = 0 @@ -115,28 +115,28 @@ private[spark] object NNLS { } i = i + 1 } - val ngrad = blas.ddot(n, grad, 1, grad, 1) + val ngrad = BLAS.nativeBLAS.ddot(n, grad, 1, grad, 1) - blas.dcopy(n, grad, 1, dir, 1) + BLAS.nativeBLAS.dcopy(n, grad, 1, dir, 1) // use a CG direction under certain conditions var step = steplen(grad, res) var ndir = 0.0 - val nx = blas.ddot(n, x, 1, x, 1) + val nx = BLAS.nativeBLAS.ddot(n, x, 1, x, 1) if (iterno > lastWall + 1) { val alpha = ngrad / lastNorm - blas.daxpy(n, alpha, lastDir, 1, dir, 1) + BLAS.nativeBLAS.daxpy(n, alpha, lastDir, 1, dir, 1) val dstep = steplen(dir, res) - ndir = blas.ddot(n, dir, 1, dir, 1) + ndir = BLAS.nativeBLAS.ddot(n, dir, 1, dir, 1) if (stop(dstep, ndir, nx)) { // reject the CG step if it could lead to premature termination - blas.dcopy(n, grad, 1, dir, 1) - ndir = blas.ddot(n, dir, 1, dir, 1) + BLAS.nativeBLAS.dcopy(n, grad, 1, dir, 1) + ndir = BLAS.nativeBLAS.ddot(n, dir, 1, dir, 1) } else { step = dstep } } else { - ndir = blas.ddot(n, dir, 1, dir, 1) + ndir = BLAS.nativeBLAS.ddot(n, dir, 1, dir, 1) } // terminate? 
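For context on the loop above (illustrative aside, not part of the patch): `ata` and `atb` are the normal-equation inputs A-transpose-A and A-transpose-b, and the "residual" built with dgemv/daxpy is the gradient of the quadratic objective 0.5 * x'(ata)x - (atb)'x, which a finite-difference check confirms on a toy problem:

```scala
val ata = Array(Array(4.0, 1.0), Array(1.0, 3.0))   // symmetric positive definite
val atb = Array(1.0, 2.0)
val x = Array(0.5, 0.25)

def f(v: Array[Double]): Double = {
  val quad = (for (i <- 0 until 2; j <- 0 until 2) yield 0.5 * v(i) * ata(i)(j) * v(j)).sum
  val linear = atb.zip(v).map { case (a, b) => a * b }.sum
  quad - linear
}

val res = Array.tabulate(2)(i => ata(i).zip(x).map { case (a, b) => a * b }.sum - atb(i))
val eps = 1e-6
val numGrad = Array.tabulate(2) { i =>
  val xp = x.clone; val xm = x.clone
  xp(i) += eps; xm(i) -= eps
  (f(xp) - f(xm)) / (2 * eps)
}
assert(res.zip(numGrad).forall { case (a, b) => math.abs(a - b) < 1e-4 })
```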
@@ -166,7 +166,7 @@ private[spark] object NNLS { } iterno = iterno + 1 - blas.dcopy(n, dir, 1, lastDir, 1) + BLAS.nativeBLAS.dcopy(n, dir, 1, lastDir, 1) lastNorm = ngrad } x.clone diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index b1be5225ce..3276513213 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -21,7 +21,6 @@ import java.io.IOException import java.lang.{Integer => JavaInteger} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus -import com.github.fommil.netlib.BLAS.{getInstance => blas} import com.google.common.collect.{Ordering => GuavaOrdering} import org.apache.hadoop.fs.Path import org.json4s._ @@ -85,7 +84,7 @@ class MatrixFactorizationModel @Since("0.8.0") ( val userVector = userFeatureSeq.head val productVector = productFeatureSeq.head - blas.ddot(rank, userVector, 1, productVector, 1) + BLAS.nativeBLAS.ddot(rank, userVector, 1, productVector, 1) } /** @@ -136,7 +135,7 @@ class MatrixFactorizationModel @Since("0.8.0") ( } users.join(productFeatures).map { case (product, ((user, uFeatures), pFeatures)) => - Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + Rating(user, product, BLAS.nativeBLAS.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) } } else { val products = productFeatures.join(usersProducts.map(_.swap)).map { @@ -144,7 +143,7 @@ class MatrixFactorizationModel @Since("0.8.0") ( } products.join(userFeatures).map { case (user, ((product, pFeatures), uFeatures)) => - Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) + Rating(user, product, BLAS.nativeBLAS.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1)) } } } @@ -263,7 +262,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { recommendableFeatures: RDD[(Int, Array[Double])], num: Int): Array[(Int, Double)] = { val scored = recommendableFeatures.map { case (id, features) => - (id, blas.ddot(features.length, recommendToFeatures, 1, features, 1)) + (id, BLAS.nativeBLAS.ddot(features.length, recommendToFeatures, 1, features, 1)) } scored.top(num)(Ordering.by(_._2)) } @@ -320,7 +319,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { Iterator.range(0, m).flatMap { i => // scores = i-th vec in srcMat * dstMat - BLAS.f2jBLAS.dgemv("T", rank, n, 1.0F, dstMat, 0, rank, + BLAS.javaBLAS.dgemv("T", rank, n, 1.0F, dstMat, 0, rank, srcMat, i * rank, 1, 0.0F, scores, 0, 1) val srcId = srcIds(i) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index f253963270..f0236f0528 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.stat -import com.github.fommil.netlib.BLAS.{getInstance => blas} - import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.rdd.RDD /** @@ -99,10 +98,10 @@ class KernelDensity extends Serializable { (x._1, x._2 + 1) }, (x, y) => { - blas.daxpy(n, 1.0, y._1, 1, x._1, 1) + BLAS.nativeBLAS.daxpy(n, 1.0, y._1, 1, x._1, 1) (x._1, x._2 + y._2) }) - 
blas.dscal(n, 1.0 / count, densities, 1) + BLAS.nativeBLAS.dscal(n, 1.0 / count, densities, 1) densities } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index c5069277fa..1f879a4d9d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree.model import scala.collection.mutable -import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -28,6 +27,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo @@ -280,7 +280,7 @@ private[tree] sealed class TreeEnsembleModel( */ private def predictBySumming(features: Vector): Double = { val treePredictions = trees.map(_.predict(features)) - blas.ddot(numTrees, treePredictions, 1, treeWeights, 1) + BLAS.nativeBLAS.ddot(numTrees, treePredictions, 1, treeWeights, 1) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index 9fffa508af..0f99cef665 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -19,10 +19,9 @@ package org.apache.spark.mllib.util import scala.util.Random -import com.github.fommil.netlib.BLAS.{getInstance => blas} - import org.apache.spark.SparkContext import org.apache.spark.annotation.Since +import org.apache.spark.ml.linalg.BLAS import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -61,7 +60,8 @@ object SVMDataGenerator { val x = Array.fill[Double](nfeatures) { rnd.nextDouble() * 2.0 - 1.0 } - val yD = blas.ddot(trueWeights.length, x, 1, trueWeights, 1) + rnd.nextGaussian() * 0.1 + val yD = BLAS.nativeBLAS.ddot(trueWeights.length, x, 1, trueWeights, 1) + + rnd.nextGaussian() * 0.1 val y = if (yD < 0) 0.0 else 1.0 LabeledPoint(y, Vectors.dense(x)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index d8b9c6a606..d18a950a01 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -23,9 +23,8 @@ import breeze.linalg.{DenseVector => BDV} import org.scalatest.Assertions._ import org.apache.spark.ml.classification.LinearSVCSuite._ -import org.apache.spark.ml.feature.{Instance, LabeledPoint} +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.ml.optim.aggregator.HingeAggregator import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -176,28 +175,13 @@ class LinearSVCSuite extends MLTest with 
DefaultReadWriteTest { assert(model2.intercept !== 0.0) } - test("sparse coefficients in HingeAggregator") { - val bcCoefficients = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0))) - val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0)) - val agg = new HingeAggregator(bcFeaturesStd, true)(bcCoefficients) - val thrown = withClue("LinearSVCAggregator cannot handle sparse coefficients") { - intercept[IllegalArgumentException] { - agg.add(Instance(1.0, 1.0, Vectors.dense(1.0))) - } - } - assert(thrown.getMessage.contains("coefficients only supports dense")) - - bcCoefficients.destroy() - bcFeaturesStd.destroy() - } - test("linearSVC with sample weights") { def modelEquals(m1: LinearSVCModel, m2: LinearSVCModel): Unit = { - assert(m1.coefficients ~== m2.coefficients absTol 0.05) + assert(m1.coefficients ~== m2.coefficients relTol 0.05) assert(m1.intercept ~== m2.intercept absTol 0.05) } - val estimator = new LinearSVC().setRegParam(0.01).setTol(0.01) + val estimator = new LinearSVC().setRegParam(0.01).setTol(0.001) val dataset = smallBinaryDataset MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC]( dataset.as[LabeledPoint], estimator, modelEquals) @@ -237,7 +221,7 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { val model1 = trainer1.fit(binaryDataset) /* - Use the following R code to load the data and train the model using glmnet package. + Use the following R code to load the data and train the model using e1071 package. library(e1071) data <- read.csv("path/target/tmp/LinearSVC/binaryDataset/part-00000", header=FALSE) @@ -257,8 +241,8 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { */ val coefficientsR = Vectors.dense(7.310338, 14.89741, 22.21005, 29.83508) val interceptR = 7.440177 - assert(model1.intercept ~== interceptR relTol 1E-2) - assert(model1.coefficients ~== coefficientsR relTol 1E-2) + assert(model1.intercept ~== interceptR relTol 1E-3) + assert(model1.coefficients ~== coefficientsR relTol 5E-3) /* Use the following python code to load the data and train the model using scikit-learn package. @@ -280,8 +264,8 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { val coefficientsSK = Vectors.dense(7.24690165, 14.77029087, 21.99924004, 29.5575729) val interceptSK = 7.36947518 - assert(model1.intercept ~== interceptSK relTol 1E-3) - assert(model1.coefficients ~== coefficientsSK relTol 4E-3) + assert(model1.intercept ~== interceptSK relTol 1E-2) + assert(model1.coefficients ~== coefficientsSK relTol 1E-2) } test("summary and training summary") { @@ -379,8 +363,8 @@ object LinearSVCSuite { } def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = { - assert(model1.intercept == model2.intercept) - assert(model1.coefficients.equals(model2.coefficients)) + assert(model1.intercept ~== model2.intercept relTol 1e-9) + assert(model1.coefficients ~== model2.coefficients relTol 1e-9) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregatorSuite.scala new file mode 100644 index 0000000000..029911adb4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeBlockAggregatorSuite.scala @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.{Instance, InstanceBlock} +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.stat.Summarizer +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class HingeBlockAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var instances: Array[Instance] = _ + @transient var instancesConstantFeature: Array[Instance] = _ + @transient var instancesConstantFeatureFiltered: Array[Instance] = _ + @transient var scaledInstances: Array[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + instances = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(0.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + instancesConstantFeature = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)), + Instance(1.0, 0.3, Vectors.dense(1.0, 0.5)) + ) + instancesConstantFeatureFiltered = Array( + Instance(0.0, 0.1, Vectors.dense(2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0)), + Instance(1.0, 0.3, Vectors.dense(0.5)) + ) + scaledInstances = standardize(instances) + } + + + /** Get summary statistics for some data and create a new HingeBlockAggregator. 
*/ + private def getNewAggregator( + instances: Array[Instance], + coefficients: Vector, + fitIntercept: Boolean): HingeBlockAggregator = { + val (featuresSummarizer, _) = + Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + val featuresStd = featuresSummarizer.std.toArray + val featuresMean = featuresSummarizer.mean.toArray + val inverseStd = featuresStd.map(std => if (std != 0) 1.0 / std else 0.0) + val scaledMean = inverseStd.zip(featuresMean).map(t => t._1 * t._2) + val bcInverseStd = sc.broadcast(inverseStd) + val bcScaledMean = sc.broadcast(scaledMean) + val bcCoefficients = sc.broadcast(coefficients) + new HingeBlockAggregator(bcInverseStd, bcScaledMean, fitIntercept)(bcCoefficients) + } + + test("sparse coefficients") { + val bcInverseStd = sc.broadcast(Array(1.0)) + val bcScaledMean = sc.broadcast(Array(2.0)) + val bcCoefficients = sc.broadcast(Vectors.sparse(2, Array(0), Array(1.0))) + val binaryAgg = new HingeBlockAggregator(bcInverseStd, bcScaledMean, + fitIntercept = false)(bcCoefficients) + val block = InstanceBlock.fromInstances(Seq(Instance(1.0, 1.0, Vectors.dense(1.0)))) + val thrownBinary = withClue("aggregator cannot handle sparse coefficients") { + intercept[IllegalArgumentException] { + binaryAgg.add(block) + } + } + assert(thrownBinary.getMessage.contains("coefficients only supports dense")) + } + + test("aggregator add method input size") { + val coefArray = Array(1.0, 2.0) + val interceptValue = 4.0 + val agg = getNewAggregator(instances, Vectors.dense(coefArray :+ interceptValue), + fitIntercept = true) + val block = InstanceBlock.fromInstances(Seq(Instance(1.0, 1.0, Vectors.dense(2.0)))) + withClue("BinaryLogisticBlockAggregator features dimension must match coefficients dimension") { + intercept[IllegalArgumentException] { + agg.add(block) + } + } + } + + test("negative weight") { + val coefArray = Array(1.0, 2.0) + val interceptValue = 4.0 + val agg = getNewAggregator(instances, Vectors.dense(coefArray :+ interceptValue), + fitIntercept = true) + val block = InstanceBlock.fromInstances(Seq(Instance(1.0, -1.0, Vectors.dense(2.0, 1.0)))) + withClue("BinaryLogisticBlockAggregator does not support negative instance weights") { + intercept[IllegalArgumentException] { + agg.add(block) + } + } + } + + test("check sizes") { + val rng = new scala.util.Random + val numFeatures = instances.head.features.size + val coefWithIntercept = Vectors.dense(Array.fill(numFeatures + 1)(rng.nextDouble)) + val coefWithoutIntercept = Vectors.dense(Array.fill(numFeatures)(rng.nextDouble)) + val block = InstanceBlock.fromInstances(instances) + + val aggIntercept = getNewAggregator(instances, coefWithIntercept, fitIntercept = true) + aggIntercept.add(block) + assert(aggIntercept.gradient.size === numFeatures + 1) + + val aggNoIntercept = getNewAggregator(instances, coefWithoutIntercept, fitIntercept = false) + aggNoIntercept.add(block) + assert(aggNoIntercept.gradient.size === numFeatures) + } + + test("check correctness: fitIntercept = false") { + val coefVec = Vectors.dense(1.0, 2.0) + val numFeatures = instances.head.features.size + val (featuresSummarizer, _) = + Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + val featuresStd = featuresSummarizer.std + val stdCoefVec = Vectors.dense(Array.tabulate(coefVec.size)(i => coefVec(i) / featuresStd(i))) + val weightSum = instances.map(_.weight).sum + + // compute the loss and the gradients + var lossSum = 0.0 + val gradientCoef = Array.ofDim[Double](numFeatures) + instances.foreach { case Instance(l, w, 
f) => + val margin = BLAS.dot(stdCoefVec, f) + val labelScaled = 2 * l - 1.0 + if (1.0 > labelScaled * margin) { + lossSum += (1.0 - labelScaled * margin) * w + gradientCoef.indices.foreach { i => + gradientCoef(i) += f(i) * -(2 * l - 1.0) * w / featuresStd(i) + } + } + } + val loss = lossSum / weightSum + val gradient = Vectors.dense(gradientCoef.map(_ / weightSum)) + + Seq(1, 2, 4).foreach { blockSize => + val blocks1 = scaledInstances + .grouped(blockSize) + .map(seq => InstanceBlock.fromInstances(seq)) + .toArray + val blocks2 = blocks1.map { block => + new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) + } + + Seq(blocks1, blocks2).foreach { blocks => + val agg = getNewAggregator(instances, coefVec, fitIntercept = false) + blocks.foreach(agg.add) + assert(agg.loss ~== loss relTol 1e-9) + assert(agg.gradient ~== gradient relTol 1e-9) + } + } + } + + test("check correctness: fitIntercept = true") { + val coefVec = Vectors.dense(1.0, 2.0) + val interceptValue = 1.0 + val numFeatures = instances.head.features.size + val (featuresSummarizer, _) = + Summarizer.getClassificationSummarizers(sc.parallelize(instances)) + val featuresStd = featuresSummarizer.std + val featuresMean = featuresSummarizer.mean + val stdCoefVec = Vectors.dense(Array.tabulate(coefVec.size)(i => coefVec(i) / featuresStd(i))) + val weightSum = instances.map(_.weight).sum + + // compute the loss and the gradients + var lossSum = 0.0 + val gradientCoef = Array.ofDim[Double](numFeatures) + var gradientIntercept = 0.0 + instances.foreach { case Instance(l, w, f) => + val centered = f.toDense.copy + BLAS.axpy(-1.0, featuresMean, centered) + val margin = BLAS.dot(stdCoefVec, centered) + interceptValue + val labelScaled = 2 * l - 1.0 + if (1.0 > labelScaled * margin) { + lossSum += (1.0 - labelScaled * margin) * w + gradientCoef.indices.foreach { i => + gradientCoef(i) += (f(i) - featuresMean(i)) * -(2 * l - 1.0) * w / featuresStd(i) + } + gradientIntercept += -(2 * l - 1.0) * w + } + } + val loss = lossSum / weightSum + val gradient = Vectors.dense((gradientCoef :+ gradientIntercept).map(_ / weightSum)) + + Seq(1, 2, 4).foreach { blockSize => + val blocks1 = scaledInstances + .grouped(blockSize) + .map(seq => InstanceBlock.fromInstances(seq)) + .toArray + val blocks2 = blocks1.map { block => + new InstanceBlock(block.labels, block.weights, block.matrix.toSparseRowMajor) + } + + Seq(blocks1, blocks2).foreach { blocks => + val agg = getNewAggregator(instances, Vectors.dense(coefVec.toArray :+ interceptValue), + fitIntercept = true) + blocks.foreach(agg.add) + assert(agg.loss ~== loss relTol 1e-9) + assert(agg.gradient ~== gradient relTol 1e-9) + } + } + } + + test("check with zero standard deviation") { + val coefArray = Array(1.0, 2.0) + val coefArrayFiltered = Array(2.0) + val interceptValue = 1.0 + + Seq(false, true).foreach { fitIntercept => + val coefVec = if (fitIntercept) { + Vectors.dense(coefArray :+ interceptValue) + } else { + Vectors.dense(coefArray) + } + val aggConstantFeature = getNewAggregator(instancesConstantFeature, + coefVec, fitIntercept = fitIntercept) + aggConstantFeature + .add(InstanceBlock.fromInstances(standardize(instancesConstantFeature))) + val grad = aggConstantFeature.gradient + + val coefVecFiltered = if (fitIntercept) { + Vectors.dense(coefArrayFiltered :+ interceptValue) + } else { + Vectors.dense(coefArrayFiltered) + } + val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered, + coefVecFiltered, fitIntercept = fitIntercept) + 
aggConstantFeatureFiltered + .add(InstanceBlock.fromInstances(standardize(instancesConstantFeatureFiltered))) + val gradFiltered = aggConstantFeatureFiltered.gradient + + // constant features should not affect gradient + assert(aggConstantFeature.loss ~== aggConstantFeatureFiltered.loss relTol 1e-9) + assert(grad(0) === 0) + assert(grad(1) ~== gradFiltered(0) relTol 1e-9) + if (fitIntercept) { + assert(grad.toArray.last ~== gradFiltered.toArray.last relTol 1e-9) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 12ab2ac3cc..91d1e9a447 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.util.TestingUtils._ class BLASSuite extends SparkFunSuite { test("nativeL1Threshold") { - assert(getBLAS(128) == BLAS.f2jBLAS) + assert(getBLAS(128) == BLAS.javaBLAS) assert(getBLAS(256) == BLAS.nativeBLAS) assert(getBLAS(512) == BLAS.nativeBLAS) } diff --git a/pom.xml b/pom.xml index 22d794ccde..9402fd4528 100644 --- a/pom.xml +++ b/pom.xml @@ -133,12 +133,12 @@ 2.3 - 2.6.0 + 2.8.0 10.14.2.0 1.12.0 1.6.7 - 9.4.39.v20210325 + 9.4.40.v20210413 4.0.3 0.9.5 2.4.0 @@ -172,6 +172,7 @@ 2.12.2 1.1.8.2 1.1.2 + 1.3.2 1.15 1.20 2.8.0 @@ -2455,6 +2456,21 @@ commons-cli ${commons-cli.version} + + dev.ludovic.netlib + blas + ${netlib.ludovic.dev.version} + + + dev.ludovic.netlib + lapack + ${netlib.ludovic.dev.version} + + + dev.ludovic.netlib + arpack + ${netlib.ludovic.dev.version} + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 54ac3c19fa..906065ca09 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -294,7 +294,7 @@ object SparkBuild extends PomBuild { javaOptions ++= { val versionParts = System.getProperty("java.version").split("[+.\\-]+", 3) var major = versionParts(0).toInt - if (major >= 16) Seq("--add-modules=jdk.incubator.vector") else Seq.empty + if (major >= 16) Seq("--add-modules=jdk.incubator.vector,jdk.incubator.foreign", "-Dforeign.restricted=warn") else Seq.empty }, (Compile / doc / javacOptions) ++= { @@ -414,6 +414,10 @@ object SparkBuild extends PomBuild { enable(YARN.settings)(yarn) + if (profiles.contains("sparkr")) { + enable(SparkR.settings)(core) + } + /** * Adds the ability to run the spark shell directly from SBT without building an assembly * jar. @@ -888,6 +892,25 @@ object PySparkAssembly { } +object SparkR { + import scala.sys.process.Process + + val buildRPackage = taskKey[Unit]("Build the R package") + lazy val settings = Seq( + buildRPackage := { + val command = baseDirectory.value / ".." / "R" / "install-dev.sh" + Process(command.toString).!! 
+ }, + (Compile / compile) := (Def.taskDyn { + val c = (Compile / compile).value + Def.task { + (Compile / buildRPackage).value + c + } + }).value + ) +} + object Unidoc { import BuildCommons._ diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 17994ed5e3..620760905a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -571,9 +571,9 @@ class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadabl >>> model.getMaxBlockSizeInMB() 0.0 >>> model.coefficients - DenseVector([0.0, -0.2792, -0.1833]) + DenseVector([0.0, -1.0319, -0.5159]) >>> model.intercept - 1.0206118982229047 + 2.579645978780695 >>> model.numClasses 2 >>> model.numFeatures @@ -582,12 +582,12 @@ class LinearSVC(_JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadabl >>> model.predict(test0.head().features) 1.0 >>> model.predictRaw(test0.head().features) - DenseVector([-1.4831, 1.4831]) + DenseVector([-4.1274, 4.1274]) >>> result = model.transform(test0).head() >>> result.newPrediction 1.0 >>> result.rawPrediction - DenseVector([-1.4831, 1.4831]) + DenseVector([-4.1274, 4.1274]) >>> svm_path = temp_path + "/svm" >>> svm.save(svm_path) >>> svm2 = LinearSVC.load(svm_path) diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py index fb245a3d05..1eadbd6942 100644 --- a/python/pyspark/ml/functions.py +++ b/python/pyspark/ml/functions.py @@ -71,7 +71,8 @@ def vector_to_array(col, dtype="float64"): def array_to_vector(col): """ - Converts a column of array of numeric type into a column of dense vectors in MLlib + Converts a column of array of numeric type into a column of pyspark.ml.linalg.DenseVector + instances .. versionadded:: 3.1.0 @@ -83,7 +84,7 @@ def array_to_vector(col): Returns ------- :py:class:`pyspark.sql.Column` - The converted column of MLlib dense vectors. + The converted column of dense vectors. Examples -------- diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 28c4499f77..5bc1801a0c 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -260,9 +260,9 @@ class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable): >>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] - Row(user=0, item=2, newPrediction=0.692910...) + Row(user=0, item=2, newPrediction=0.69291...) >>> predictions[1] - Row(user=1, item=0, newPrediction=3.473569...) + Row(user=1, item=0, newPrediction=3.47356...) >>> predictions[2] Row(user=2, item=0, newPrediction=-0.899198...) 
>>> user_recs = model.recommendForAllUsers(3) diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index 7dafdcb3d6..5b31c871fb 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -223,12 +223,12 @@ def test_linear_svc_summary(self): self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) print(s.weightedTruePositiveRate) - self.assertAlmostEqual(s.weightedTruePositiveRate, 0.5, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.5, 2) - self.assertAlmostEqual(s.weightedRecall, 0.5, 2) - self.assertAlmostEqual(s.weightedPrecision, 0.25, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 0.3333333333333333, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.3333333333333333, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) + self.assertAlmostEqual(s.weightedRecall, 1.0, 2) + self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index aeb603c5e7..a0eb243a6c 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -22,7 +22,6 @@ import numpy as np import pandas as pd -import pyspark import pyspark.pandas as ps from pyspark.pandas.exceptions import PandasNotImplementedError @@ -32,10 +31,10 @@ MissingPandasLikeIndex, MissingPandasLikeMultiIndex, ) -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils, SPARK_CONF_ARROW_ENABLED +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils, SPARK_CONF_ARROW_ENABLED -class IndexesTest(ReusedSQLTestCase, TestUtils): +class IndexesTest(PandasOnSparkTestCase, TestUtils): @property def pdf(self): return pd.DataFrame( @@ -280,12 +279,7 @@ def test_multi_index_names(self): pidx.names = ["renamed_number", None] kidx.names = ["renamed_number", None] self.assertEqual(kidx.names, pidx.names) - if LooseVersion(pyspark.__version__) < LooseVersion("2.4"): - # PySpark < 2.4 does not support struct type with arrow enabled. 
- with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - self.assert_eq(kidx, pidx) - else: - self.assert_eq(kidx, pidx) + self.assert_eq(kidx, pidx) with self.assertRaises(PandasNotImplementedError): kidx.name @@ -1401,11 +1395,7 @@ def test_asof(self): self.assert_eq(kidx.asof("2014-01-01"), pidx.asof("2014-01-01")) self.assert_eq(kidx.asof("2014-01-02"), pidx.asof("2014-01-02")) - if LooseVersion(pyspark.__version__) >= LooseVersion("3.0"): - self.assert_eq(repr(kidx.asof("1999-01-02")), repr(pidx.asof("1999-01-02"))) - else: - # FIXME: self.assert_eq(repr(kidx.asof("1999-01-02")), repr(pidx.asof("1999-01-02"))) - pass + self.assert_eq(repr(kidx.asof("1999-01-02")), repr(pidx.asof("1999-01-02"))) # Decreasing values pidx = pd.Index(["2014-01-03", "2014-01-02", "2013-12-31"]) @@ -1427,11 +1417,7 @@ def test_asof(self): self.assert_eq(kidx.asof("2014-01-01"), pd.Timestamp("2014-01-02 00:00:00")) self.assert_eq(kidx.asof("2014-01-02"), pd.Timestamp("2014-01-02 00:00:00")) self.assert_eq(kidx.asof("1999-01-02"), pd.Timestamp("2013-12-31 00:00:00")) - if LooseVersion(pyspark.__version__) >= LooseVersion("3.0"): - self.assert_eq(repr(kidx.asof("2015-01-02")), repr(pd.NaT)) - else: - # FIXME: self.assert_eq(repr(kidx.asof("2015-01-02")), repr(pd.NaT)) - pass + self.assert_eq(repr(kidx.asof("2015-01-02")), repr(pd.NaT)) # Not increasing, neither decreasing (ValueError) kidx = ps.Index(["2013-12-31", "2015-01-02", "2014-01-03"]) @@ -2249,13 +2235,7 @@ def test_to_list(self): kmidx = ps.from_pandas(pmidx) self.assert_eq(kidx.tolist(), pidx.tolist()) - - if LooseVersion(pyspark.__version__) < LooseVersion("2.4"): - # PySpark < 2.4 does not support struct type with arrow enabled. - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - self.assert_eq(kmidx.tolist(), pmidx.tolist()) - else: - self.assert_eq(kidx.tolist(), pidx.tolist()) + self.assert_eq(kmidx.tolist(), pmidx.tolist()) def test_index_ops(self): pidx = pd.Index([1, 2, 3, 4, 5]) diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py index 0fe5eeb209..360e31863d 100644 --- a/python/pyspark/pandas/tests/indexes/test_category.py +++ b/python/pyspark/pandas/tests/indexes/test_category.py @@ -21,10 +21,10 @@ from pandas.api.types import CategoricalDtype import pyspark.pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class CategoricalIndexTest(ReusedSQLTestCase, TestUtils): +class CategoricalIndexTest(PandasOnSparkTestCase, TestUtils): def test_categorical_index(self): pidx = pd.CategoricalIndex([1, 2, 3]) kidx = ps.CategoricalIndex([1, 2, 3]) diff --git a/python/pyspark/pandas/tests/indexes/test_datetime.py b/python/pyspark/pandas/tests/indexes/test_datetime.py index 407565b46d..af511ed6c2 100644 --- a/python/pyspark/pandas/tests/indexes/test_datetime.py +++ b/python/pyspark/pandas/tests/indexes/test_datetime.py @@ -22,10 +22,10 @@ import pandas as pd import pyspark.pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class DatetimeIndexTest(ReusedSQLTestCase, TestUtils): +class DatetimeIndexTest(PandasOnSparkTestCase, TestUtils): @property def fixed_freqs(self): return [ diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot.py b/python/pyspark/pandas/tests/plot/test_frame_plot.py index 70822e1e32..b57acd4959 100644 --- 
a/python/pyspark/pandas/tests/plot/test_frame_plot.py +++ b/python/pyspark/pandas/tests/plot/test_frame_plot.py @@ -22,10 +22,10 @@ from pyspark.pandas.config import set_option, reset_option, option_context from pyspark.pandas.plot import TopNPlotBase, SampledPlotBase, HistogramPlotBase from pyspark.pandas.exceptions import PandasNotImplementedError -from pyspark.pandas.testing.utils import ReusedSQLTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestCase -class DataFramePlotTest(ReusedSQLTestCase): +class DataFramePlotTest(PandasOnSparkTestCase): @classmethod def setUpClass(cls): super().setUpClass() diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py index 6e8f3c0256..5de5c90856 100644 --- a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +++ b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py @@ -25,7 +25,12 @@ from pyspark import pandas as ps from pyspark.pandas.config import set_option, reset_option -from pyspark.pandas.testing.utils import have_matplotlib, ReusedSQLTestCase, TestUtils +from pyspark.testing.pandasutils import ( + have_matplotlib, + matplotlib_requirement_message, + PandasOnSparkTestCase, + TestUtils, +) if have_matplotlib: import matplotlib @@ -34,8 +39,8 @@ matplotlib.use("agg") -@unittest.skipIf(not have_matplotlib, "matplotlib is not installed.") -class DataFramePlotMatplotlibTest(ReusedSQLTestCase, TestUtils): +@unittest.skipIf(not have_matplotlib, matplotlib_requirement_message) +class DataFramePlotMatplotlibTest(PandasOnSparkTestCase, TestUtils): sample_ratio_default = None @classmethod diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py index dca5a4307b..33d6bef2c8 100644 --- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py @@ -24,7 +24,12 @@ from pyspark import pandas as ps from pyspark.pandas.config import set_option, reset_option -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils, have_plotly +from pyspark.testing.pandasutils import ( + have_plotly, + plotly_requirement_message, + PandasOnSparkTestCase, + TestUtils, +) from pyspark.pandas.utils import name_like_string if have_plotly: @@ -34,10 +39,10 @@ @unittest.skipIf( not have_plotly or LooseVersion(pd.__version__) < "1.0.0", - "plotly is not installed or pandas<1.0. 
pandas<1.0 does not support latest plotly " + plotly_requirement_message + " Or pandas<1.0; pandas<1.0 does not support latest plotly " "and/or 'plotting.backend' option.", ) -class DataFramePlotPlotlyTest(ReusedSQLTestCase, TestUtils): +class DataFramePlotPlotlyTest(PandasOnSparkTestCase, TestUtils): @classmethod def setUpClass(cls): super().setUpClass() diff --git a/python/pyspark/pandas/tests/plot/test_series_plot.py b/python/pyspark/pandas/tests/plot/test_series_plot.py index 4292c960a2..fbfda88648 100644 --- a/python/pyspark/pandas/tests/plot/test_series_plot.py +++ b/python/pyspark/pandas/tests/plot/test_series_plot.py @@ -22,7 +22,7 @@ from pyspark import pandas as ps from pyspark.pandas.plot import PandasOnSparkPlotAccessor, BoxPlotBase -from pyspark.pandas.testing.utils import have_plotly +from pyspark.testing.pandasutils import have_plotly, plotly_requirement_message class SeriesPlotTest(unittest.TestCase): @@ -36,7 +36,7 @@ def pdf1(self): def kdf1(self): return ps.from_pandas(self.pdf1) - @unittest.skipIf(not have_plotly, "plotly is not installed") + @unittest.skipIf(not have_plotly, plotly_requirement_message) def test_plot_backends(self): plot_backend = "plotly" diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py index 6bef5c9316..364a39bff8 100644 --- a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +++ b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py @@ -25,7 +25,12 @@ from pyspark import pandas as ps from pyspark.pandas.config import set_option, reset_option -from pyspark.pandas.testing.utils import have_matplotlib, ReusedSQLTestCase, TestUtils +from pyspark.testing.pandasutils import ( + have_matplotlib, + matplotlib_requirement_message, + PandasOnSparkTestCase, + TestUtils, +) if have_matplotlib: import matplotlib @@ -34,8 +39,8 @@ matplotlib.use("agg") -@unittest.skipIf(not have_matplotlib, "matplotlib is not installed.") -class SeriesPlotMatplotlibTest(ReusedSQLTestCase, TestUtils): +@unittest.skipIf(not have_matplotlib, matplotlib_requirement_message) +class SeriesPlotMatplotlibTest(PandasOnSparkTestCase, TestUtils): @classmethod def setUpClass(cls): super().setUpClass() diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py index 5c0f2f7e89..2a14d373d2 100644 --- a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +++ b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py @@ -24,8 +24,13 @@ from pyspark import pandas as ps from pyspark.pandas.config import set_option, reset_option -from pyspark.pandas.testing.utils import have_plotly, ReusedSQLTestCase, TestUtils from pyspark.pandas.utils import name_like_string +from pyspark.testing.pandasutils import ( + have_plotly, + plotly_requirement_message, + PandasOnSparkTestCase, + TestUtils, +) if have_plotly: from plotly import express @@ -34,10 +39,10 @@ @unittest.skipIf( not have_plotly or LooseVersion(pd.__version__) < "1.0.0", - "plotly is not installed or pandas<1.0. 
pandas<1.0 does not support latest plotly " + plotly_requirement_message + " Or pandas<1.0; pandas<1.0 does not support latest plotly " "and/or 'plotting.backend' option.", ) -class SeriesPlotPlotlyTest(ReusedSQLTestCase, TestUtils): +class SeriesPlotPlotlyTest(PandasOnSparkTestCase, TestUtils): @classmethod def setUpClass(cls): super().setUpClass() diff --git a/python/pyspark/pandas/tests/test_categorical.py b/python/pyspark/pandas/tests/test_categorical.py index 90e37ddbf6..28de94bbcb 100644 --- a/python/pyspark/pandas/tests/test_categorical.py +++ b/python/pyspark/pandas/tests/test_categorical.py @@ -22,10 +22,10 @@ from pandas.api.types import CategoricalDtype import pyspark.pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class CategoricalTest(ReusedSQLTestCase, TestUtils): +class CategoricalTest(PandasOnSparkTestCase, TestUtils): @property def pdf(self): return pd.DataFrame( diff --git a/python/pyspark/pandas/tests/test_config.py b/python/pyspark/pandas/tests/test_config.py index 1fb2cd344a..ba717a9712 100644 --- a/python/pyspark/pandas/tests/test_config.py +++ b/python/pyspark/pandas/tests/test_config.py @@ -18,10 +18,10 @@ from pyspark import pandas as ps from pyspark.pandas import config from pyspark.pandas.config import Option, DictWrapper -from pyspark.pandas.testing.utils import ReusedSQLTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestCase -class ConfigTest(ReusedSQLTestCase): +class ConfigTest(PandasOnSparkTestCase): def setUp(self): config._options_dict["test.config"] = Option(key="test.config", doc="", default="default") diff --git a/python/pyspark/pandas/tests/test_csv.py b/python/pyspark/pandas/tests/test_csv.py index 7d32d819b5..17b3060c92 100644 --- a/python/pyspark/pandas/tests/test_csv.py +++ b/python/pyspark/pandas/tests/test_csv.py @@ -24,14 +24,14 @@ import numpy as np from pyspark import pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils def normalize_text(s): return "\n".join(map(str.strip, s.strip().split("\n"))) -class CsvTest(ReusedSQLTestCase, TestUtils): +class CsvTest(PandasOnSparkTestCase, TestUtils): def setUp(self): self.tmp_dir = tempfile.mkdtemp(prefix=CsvTest.__name__) diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 6fa4933c25..d7cb3ab359 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -25,7 +25,6 @@ import numpy as np import pandas as pd from pandas.tseries.offsets import DateOffset -import pyspark from pyspark import StorageLevel from pyspark.ml.linalg import SparseVector from pyspark.sql import functions as F @@ -41,16 +40,17 @@ extension_float_dtypes_available, extension_object_dtypes_available, ) -from pyspark.pandas.testing.utils import ( +from pyspark.testing.pandasutils import ( have_tabulate, - ReusedSQLTestCase, - SQLTestUtils, + PandasOnSparkTestCase, SPARK_CONF_ARROW_ENABLED, + tabulate_requirement_message, ) +from pyspark.testing.sqlutils import SQLTestUtils from pyspark.pandas.utils import name_like_string -class DataFrameTest(ReusedSQLTestCase, SQLTestUtils): +class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils): @property def pdf(self): return pd.DataFrame( @@ -565,11 +565,7 @@ def test_empty_dataframe(self): pdf = pd.DataFrame({"a": pd.Series([], dtype="i1"), "b": 
pd.Series([], dtype="str")}) kdf = ps.from_pandas(pdf) - if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): - self.assert_eq(kdf, pdf) - else: - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - self.assert_eq(kdf, pdf) + self.assert_eq(kdf, pdf) with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): kdf = ps.from_pandas(pdf) @@ -601,11 +597,7 @@ def test_all_null_dataframe(self): ) kdf = ps.from_pandas(pdf) - if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): - self.assert_eq(kdf, pdf) - else: - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - self.assert_eq(kdf, pdf) + self.assert_eq(kdf, pdf) with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): kdf = ps.from_pandas(pdf) @@ -2990,10 +2982,6 @@ def test_pivot_table_and_index(self): self.assert_eq(ktable.index, ptable.index) self.assert_eq(repr(ktable.index), repr(ptable.index)) - @unittest.skipIf( - LooseVersion(pyspark.__version__) < LooseVersion("2.4"), - "stack won't work properly with PySpark<2.4", - ) def test_stack(self): pdf_single_level_cols = pd.DataFrame( [[0, 1], [2, 3]], index=["cat", "dog"], columns=["weight", "height"] @@ -3235,22 +3223,13 @@ def _test_cumprod(self, pdf, kdf): self.assert_eq(pdf.cumprod().sum(), kdf.cumprod().sum(), almost=True) def test_cumprod(self): - if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): - pdf = pd.DataFrame( - [[2.0, 1.0, 1], [5, None, 2], [1.0, -1.0, -3], [2.0, 0, 4], [4.0, 9.0, 5]], - columns=list("ABC"), - index=np.random.rand(5), - ) - kdf = ps.from_pandas(pdf) - self._test_cumprod(pdf, kdf) - else: - pdf = pd.DataFrame( - [[2, 1, 1], [5, 1, 2], [1, -1, -3], [2, 0, 4], [4, 9, 5]], - columns=list("ABC"), - index=np.random.rand(5), - ) - kdf = ps.from_pandas(pdf) - self._test_cumprod(pdf, kdf) + pdf = pd.DataFrame( + [[2.0, 1.0, 1], [5, None, 2], [1.0, -1.0, -3], [2.0, 0, 4], [4.0, 9.0, 5]], + columns=list("ABC"), + index=np.random.rand(5), + ) + kdf = ps.from_pandas(pdf) + self._test_cumprod(pdf, kdf) def test_cumprod_multiindex_columns(self): arrays = [np.array(["A", "A", "B", "B"]), np.array(["one", "two", "one", "two"])] @@ -4725,13 +4704,8 @@ def test_udt(self): sparse_vector = SparseVector(len(sparse_values), sparse_values) pdf = pd.DataFrame({"a": [sparse_vector], "b": [10]}) - if LooseVersion(pyspark.__version__) < LooseVersion("2.4"): - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - kdf = ps.from_pandas(pdf) - self.assert_eq(kdf, pdf) - else: - kdf = ps.from_pandas(pdf) - self.assert_eq(kdf, pdf) + kdf = ps.from_pandas(pdf) + self.assert_eq(kdf, pdf) def test_eval(self): pdf = pd.DataFrame({"A": range(1, 6), "B": range(10, 0, -2)}) @@ -4767,7 +4741,7 @@ def test_eval(self): kdf.columns = columns self.assertRaises(ValueError, lambda: kdf.eval("x.a + y.b")) - @unittest.skipIf(not have_tabulate, "tabulate not installed") + @unittest.skipIf(not have_tabulate, tabulate_requirement_message) def test_to_markdown(self): pdf = pd.DataFrame(data={"animal_1": ["elk", "pig"], "animal_2": ["dog", "quetzal"]}) kdf = ps.from_pandas(pdf) @@ -5161,10 +5135,6 @@ def test_iteritems(self): self.assert_eq(p_name, k_name) self.assert_eq(p_items, k_items) - @unittest.skipIf( - LooseVersion(pyspark.__version__) < LooseVersion("3.0"), - "tail won't work properly with PySpark<3.0", - ) def test_tail(self): pdf = pd.DataFrame({"x": range(1000)}) kdf = ps.from_pandas(pdf) @@ -5184,10 +5154,6 @@ def test_tail(self): with self.assertRaisesRegex(TypeError, "bad operand type for unary -: 'str'"): kdf.tail("10") - @unittest.skipIf( - 
LooseVersion(pyspark.__version__) < LooseVersion("3.0"), - "last_valid_index won't work properly with PySpark<3.0", - ) def test_last_valid_index(self): pdf = pd.DataFrame( {"a": [1, 2, 3, None], "b": [1.0, 2.0, 3.0, None], "c": [100, 200, 400, None]}, diff --git a/python/pyspark/pandas/tests/test_dataframe_conversion.py b/python/pyspark/pandas/tests/test_dataframe_conversion.py index 8b64398634..92ddef6014 100644 --- a/python/pyspark/pandas/tests/test_dataframe_conversion.py +++ b/python/pyspark/pandas/tests/test_dataframe_conversion.py @@ -24,12 +24,13 @@ import numpy as np import pandas as pd -from pyspark import pandas as pp from distutils.version import LooseVersion -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils, TestUtils +from pyspark import pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils +from pyspark.testing.sqlutils import SQLTestUtils -class DataFrameConversionTest(ReusedSQLTestCase, SQLTestUtils, TestUtils): +class DataFrameConversionTest(PandasOnSparkTestCase, SQLTestUtils, TestUtils): """Test cases for "small data" conversion and I/O.""" def setUp(self): @@ -44,7 +45,7 @@ def pdf(self): @property def kdf(self): - return pp.from_pandas(self.pdf) + return ps.from_pandas(self.pdf) @staticmethod def strip_all_whitespace(str): @@ -113,7 +114,7 @@ def test_to_excel(self): pdf = pd.DataFrame({"a": [1, None, 3], "b": ["one", "two", None]}, index=[0, 1, 3]) - kdf = pp.from_pandas(pdf) + kdf = ps.from_pandas(pdf) kdf.to_excel(koalas_location, na_rep="null") pdf.to_excel(pandas_location, na_rep="null") @@ -122,7 +123,7 @@ def test_to_excel(self): pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]}, index=[0, 1, 3]) - kdf = pp.from_pandas(pdf) + kdf = ps.from_pandas(pdf) kdf.to_excel(koalas_location, float_format="%.1f") pdf.to_excel(pandas_location, float_format="%.1f") @@ -141,12 +142,12 @@ def test_to_excel(self): def test_to_json(self): pdf = self.pdf - kdf = pp.from_pandas(pdf) + kdf = ps.from_pandas(pdf) self.assert_eq(kdf.to_json(orient="records"), pdf.to_json(orient="records")) def test_to_json_negative(self): - kdf = pp.from_pandas(self.pdf) + kdf = ps.from_pandas(self.pdf) with self.assertRaises(NotImplementedError): kdf.to_json(orient="table") @@ -156,11 +157,11 @@ def test_to_json_negative(self): def test_read_json_negative(self): with self.assertRaises(NotImplementedError): - pp.read_json("invalid", lines=False) + ps.read_json("invalid", lines=False) def test_to_json_with_path(self): pdf = pd.DataFrame({"a": [1], "b": ["a"]}) - kdf = pp.DataFrame(pdf) + kdf = ps.DataFrame(pdf) kdf.to_json(self.tmp_dir, num_files=1) expected = pdf.to_json(orient="records") @@ -172,7 +173,7 @@ def test_to_json_with_path(self): def test_to_json_with_partition_cols(self): pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) - kdf = pp.DataFrame(pdf) + kdf = ps.DataFrame(pdf) kdf.to_json(self.tmp_dir, partition_cols="b", num_files=1) @@ -224,7 +225,7 @@ def test_to_records(self): if LooseVersion(pd.__version__) >= LooseVersion("0.24.0"): pdf = pd.DataFrame({"A": [1, 2], "B": [0.5, 0.75]}, index=["a", "b"]) - kdf = pp.from_pandas(pdf) + kdf = ps.from_pandas(pdf) self.assert_eq(kdf.to_records(), pdf.to_records()) self.assert_eq(kdf.to_records(index=False), pdf.to_records(index=False)) @@ -233,32 +234,32 @@ def test_to_records(self): def test_from_records(self): # Assert using a dict as input self.assert_eq( - pp.DataFrame.from_records({"A": [1, 2, 3]}), pd.DataFrame.from_records({"A": [1, 2, 3]}) + 
ps.DataFrame.from_records({"A": [1, 2, 3]}), pd.DataFrame.from_records({"A": [1, 2, 3]}) ) # Assert using a list of tuples as input self.assert_eq( - pp.DataFrame.from_records([(1, 2), (3, 4)]), pd.DataFrame.from_records([(1, 2), (3, 4)]) + ps.DataFrame.from_records([(1, 2), (3, 4)]), pd.DataFrame.from_records([(1, 2), (3, 4)]) ) # Assert using a NumPy array as input - self.assert_eq(pp.DataFrame.from_records(np.eye(3)), pd.DataFrame.from_records(np.eye(3))) + self.assert_eq(ps.DataFrame.from_records(np.eye(3)), pd.DataFrame.from_records(np.eye(3))) # Asserting using a custom index self.assert_eq( - pp.DataFrame.from_records([(1, 2), (3, 4)], index=[2, 3]), + ps.DataFrame.from_records([(1, 2), (3, 4)], index=[2, 3]), pd.DataFrame.from_records([(1, 2), (3, 4)], index=[2, 3]), ) # Assert excluding excluding column(s) self.assert_eq( - pp.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, exclude=["B"]), + ps.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, exclude=["B"]), pd.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, exclude=["B"]), ) # Assert limiting to certain column(s) self.assert_eq( - pp.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, columns=["A"]), + ps.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, columns=["A"]), pd.DataFrame.from_records({"A": [1, 2, 3], "B": [1, 2, 3]}, columns=["A"]), ) # Assert limiting to a number of rows self.assert_eq( - pp.DataFrame.from_records([(1, 2), (3, 4)], nrows=1), + ps.DataFrame.from_records([(1, 2), (3, 4)], nrows=1), pd.DataFrame.from_records([(1, 2), (3, 4)], nrows=1), ) diff --git a/python/pyspark/pandas/tests/test_dataframe_spark_io.py b/python/pyspark/pandas/tests/test_dataframe_spark_io.py index 818ce61dd8..f0982bd4e2 100644 --- a/python/pyspark/pandas/tests/test_dataframe_spark_io.py +++ b/python/pyspark/pandas/tests/test_dataframe_spark_io.py @@ -23,13 +23,12 @@ import numpy as np import pandas as pd import pyarrow as pa -import pyspark -from pyspark import pandas as pp -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils +from pyspark import pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class DataFrameSparkIOTest(ReusedSQLTestCase, TestUtils): +class DataFrameSparkIOTest(PandasOnSparkTestCase, TestUtils): """Test cases for big data I/O using Spark.""" @property @@ -60,7 +59,7 @@ def test_parquet_read(self): def check(columns, expected): if LooseVersion("0.21.1") <= LooseVersion(pd.__version__): expected = pd.read_parquet(tmp, columns=columns) - actual = pp.read_parquet(tmp, columns=columns) + actual = ps.read_parquet(tmp, columns=columns) self.assertPandasEqual(expected, actual.to_pandas()) check(None, data) @@ -82,24 +81,20 @@ def check(columns, expected): expected = pd.read_parquet(tmp) else: expected = data - actual = pp.read_parquet(tmp) + actual = ps.read_parquet(tmp) self.assertPandasEqual(expected, actual.to_pandas()) # When index columns are known pdf = self.test_pdf - expected = pp.DataFrame(pdf) + expected = ps.DataFrame(pdf) expected_idx = expected.set_index("bhello")[["f", "i32", "i64"]] - actual_idx = pp.read_parquet(tmp, index_col="bhello")[["f", "i32", "i64"]] + actual_idx = ps.read_parquet(tmp, index_col="bhello")[["f", "i32", "i64"]] self.assert_eq( actual_idx.sort_values(by="f").to_spark().toPandas(), expected_idx.sort_values(by="f").to_spark().toPandas(), ) - @unittest.skipIf( - LooseVersion(pyspark.__version__) < LooseVersion("3.0.0"), - "The test only works with Spark>=3.0", - ) def 
test_parquet_read_with_pandas_metadata(self): with self.temp_dir() as tmp: expected1 = self.test_pdf @@ -107,32 +102,32 @@ def test_parquet_read_with_pandas_metadata(self): path1 = "{}/file1.parquet".format(tmp) expected1.to_parquet(path1) - self.assert_eq(pp.read_parquet(path1, pandas_metadata=True), expected1) + self.assert_eq(ps.read_parquet(path1, pandas_metadata=True), expected1) expected2 = expected1.reset_index() path2 = "{}/file2.parquet".format(tmp) expected2.to_parquet(path2) - self.assert_eq(pp.read_parquet(path2, pandas_metadata=True), expected2) + self.assert_eq(ps.read_parquet(path2, pandas_metadata=True), expected2) expected3 = expected2.set_index("index", append=True) path3 = "{}/file3.parquet".format(tmp) expected3.to_parquet(path3) - self.assert_eq(pp.read_parquet(path3, pandas_metadata=True), expected3) + self.assert_eq(ps.read_parquet(path3, pandas_metadata=True), expected3) def test_parquet_write(self): with self.temp_dir() as tmp: pdf = self.test_pdf - expected = pp.DataFrame(pdf) + expected = ps.DataFrame(pdf) # Write out partitioned by one column expected.to_parquet(tmp, mode="overwrite", partition_cols="i32") # Reset column order, as once the data is written out, Spark rearranges partition # columns to appear first. - actual = pp.read_parquet(tmp) + actual = ps.read_parquet(tmp) self.assertFalse((actual.columns == self.test_column_order).all()) actual = actual[self.test_column_order] self.assert_eq( @@ -144,7 +139,7 @@ def test_parquet_write(self): expected.to_parquet(tmp, mode="overwrite", partition_cols=["i32", "bhello"]) # Reset column order, as once the data is written out, Spark rearranges partition # columns to appear first. - actual = pp.read_parquet(tmp) + actual = ps.read_parquet(tmp) self.assertFalse((actual.columns == self.test_column_order).all()) actual = actual[self.test_column_order] self.assert_eq( @@ -155,13 +150,13 @@ def test_parquet_write(self): def test_table(self): with self.table("test_table"): pdf = self.test_pdf - expected = pp.DataFrame(pdf) + expected = ps.DataFrame(pdf) # Write out partitioned by one column expected.spark.to_table("test_table", mode="overwrite", partition_cols="i32") # Reset column order, as once the data is written out, Spark rearranges partition # columns to appear first. - actual = pp.read_table("test_table") + actual = ps.read_table("test_table") self.assertFalse((actual.columns == self.test_column_order).all()) actual = actual[self.test_column_order] self.assert_eq( @@ -173,7 +168,7 @@ def test_table(self): expected.to_table("test_table", mode="overwrite", partition_cols=["i32", "bhello"]) # Reset column order, as once the data is written out, Spark rearranges partition # columns to appear first. 
- actual = pp.read_table("test_table") + actual = ps.read_table("test_table") self.assertFalse((actual.columns == self.test_column_order).all()) actual = actual[self.test_column_order] self.assert_eq( @@ -183,21 +178,21 @@ def test_table(self): # When index columns are known expected_idx = expected.set_index("bhello")[["f", "i32", "i64"]] - actual_idx = pp.read_table("test_table", index_col="bhello")[["f", "i32", "i64"]] + actual_idx = ps.read_table("test_table", index_col="bhello")[["f", "i32", "i64"]] self.assert_eq( actual_idx.sort_values(by="f").to_spark().toPandas(), expected_idx.sort_values(by="f").to_spark().toPandas(), ) expected_idx = expected.set_index(["bhello"])[["f", "i32", "i64"]] - actual_idx = pp.read_table("test_table", index_col=["bhello"])[["f", "i32", "i64"]] + actual_idx = ps.read_table("test_table", index_col=["bhello"])[["f", "i32", "i64"]] self.assert_eq( actual_idx.sort_values(by="f").to_spark().toPandas(), expected_idx.sort_values(by="f").to_spark().toPandas(), ) expected_idx = expected.set_index(["i32", "bhello"])[["f", "i64"]] - actual_idx = pp.read_table("test_table", index_col=["i32", "bhello"])[["f", "i64"]] + actual_idx = ps.read_table("test_table", index_col=["i32", "bhello"])[["f", "i64"]] self.assert_eq( actual_idx.sort_values(by="f").to_spark().toPandas(), expected_idx.sort_values(by="f").to_spark().toPandas(), @@ -206,13 +201,13 @@ def test_table(self): def test_spark_io(self): with self.temp_dir() as tmp: pdf = self.test_pdf - expected = pp.DataFrame(pdf) + expected = ps.DataFrame(pdf) # Write out partitioned by one column expected.to_spark_io(tmp, format="json", mode="overwrite", partition_cols="i32") # Reset column order, as once the data is written out, Spark rearranges partition # columns to appear first. - actual = pp.read_spark_io(tmp, format="json") + actual = ps.read_spark_io(tmp, format="json") self.assertFalse((actual.columns == self.test_column_order).all()) actual = actual[self.test_column_order] self.assert_eq( @@ -226,7 +221,7 @@ def test_spark_io(self): ) # Reset column order, as once the data is written out, Spark rearranges partition # columns to appear first. 
- actual = pp.read_spark_io(path=tmp, format="json") + actual = ps.read_spark_io(path=tmp, format="json") self.assertFalse((actual.columns == self.test_column_order).all()) actual = actual[self.test_column_order] self.assert_eq( @@ -236,11 +231,11 @@ def test_spark_io(self): # When index columns are known pdf = self.test_pdf - expected = pp.DataFrame(pdf) + expected = ps.DataFrame(pdf) col_order = ["f", "i32", "i64"] expected_idx = expected.set_index("bhello")[col_order] - actual_idx = pp.read_spark_io(tmp, format="json", index_col="bhello")[col_order] + actual_idx = ps.read_spark_io(tmp, format="json", index_col="bhello")[col_order] self.assert_eq( actual_idx.sort_values(by="f").to_spark().toPandas(), expected_idx.sort_values(by="f").to_spark().toPandas(), @@ -253,45 +248,42 @@ def test_read_excel(self): path1 = "{}/file1.xlsx".format(tmp) self.test_pdf[["i32"]].to_excel(path1) - self.assert_eq(pp.read_excel(open(path1, "rb")), pd.read_excel(open(path1, "rb"))) + self.assert_eq(ps.read_excel(open(path1, "rb")), pd.read_excel(open(path1, "rb"))) self.assert_eq( - pp.read_excel(open(path1, "rb"), index_col=0), + ps.read_excel(open(path1, "rb"), index_col=0), pd.read_excel(open(path1, "rb"), index_col=0), ) self.assert_eq( - pp.read_excel(open(path1, "rb"), index_col=0, squeeze=True), + ps.read_excel(open(path1, "rb"), index_col=0, squeeze=True), pd.read_excel(open(path1, "rb"), index_col=0, squeeze=True), ) - if LooseVersion(pyspark.__version__) >= LooseVersion("3.0.0"): - self.assert_eq(pp.read_excel(path1), pd.read_excel(path1)) - self.assert_eq(pp.read_excel(path1, index_col=0), pd.read_excel(path1, index_col=0)) - self.assert_eq( - pp.read_excel(path1, index_col=0, squeeze=True), - pd.read_excel(path1, index_col=0, squeeze=True), - ) + self.assert_eq(ps.read_excel(path1), pd.read_excel(path1)) + self.assert_eq(ps.read_excel(path1, index_col=0), pd.read_excel(path1, index_col=0)) + self.assert_eq( + ps.read_excel(path1, index_col=0, squeeze=True), + pd.read_excel(path1, index_col=0, squeeze=True), + ) - self.assert_eq(pp.read_excel(tmp), pd.read_excel(path1)) + self.assert_eq(ps.read_excel(tmp), pd.read_excel(path1)) - path2 = "{}/file2.xlsx".format(tmp) - self.test_pdf[["i32"]].to_excel(path2) - self.assert_eq( - pp.read_excel(tmp, index_col=0).sort_index(), - pd.concat( - [pd.read_excel(path1, index_col=0), pd.read_excel(path2, index_col=0)] - ).sort_index(), - ) - self.assert_eq( - pp.read_excel(tmp, index_col=0, squeeze=True).sort_index(), - pd.concat( - [ - pd.read_excel(path1, index_col=0, squeeze=True), - pd.read_excel(path2, index_col=0, squeeze=True), - ] - ).sort_index(), - ) - else: - self.assertRaises(ValueError, lambda: pp.read_excel(tmp)) + path2 = "{}/file2.xlsx".format(tmp) + self.test_pdf[["i32"]].to_excel(path2) + self.assert_eq( + ps.read_excel(tmp, index_col=0).sort_index(), + pd.concat( + [pd.read_excel(path1, index_col=0), pd.read_excel(path2, index_col=0)] + ).sort_index(), + ) + self.assert_eq( + ps.read_excel(tmp, index_col=0, squeeze=True).sort_index(), + pd.concat( + [ + pd.read_excel(path1, index_col=0, squeeze=True), + pd.read_excel(path2, index_col=0, squeeze=True), + ] + ).sort_index(), + ) with self.temp_dir() as tmp: path1 = "{}/file1.xlsx".format(tmp) @@ -307,79 +299,76 @@ def test_read_excel(self): ) for sheet_name in sheet_names: - kdfs = pp.read_excel(open(path1, "rb"), sheet_name=sheet_name, index_col=0) + kdfs = ps.read_excel(open(path1, "rb"), sheet_name=sheet_name, index_col=0) self.assert_eq(kdfs["Sheet_name_1"], pdfs1["Sheet_name_1"]) 
self.assert_eq(kdfs["Sheet_name_2"], pdfs1["Sheet_name_2"]) - kdfs = pp.read_excel( + kdfs = ps.read_excel( open(path1, "rb"), sheet_name=sheet_name, index_col=0, squeeze=True ) self.assert_eq(kdfs["Sheet_name_1"], pdfs1_squeezed["Sheet_name_1"]) self.assert_eq(kdfs["Sheet_name_2"], pdfs1_squeezed["Sheet_name_2"]) - if LooseVersion(pyspark.__version__) >= LooseVersion("3.0.0"): - self.assert_eq( - pp.read_excel(tmp, index_col=0, sheet_name="Sheet_name_2"), - pdfs1["Sheet_name_2"], - ) + self.assert_eq( + ps.read_excel(tmp, index_col=0, sheet_name="Sheet_name_2"), + pdfs1["Sheet_name_2"], + ) - for sheet_name in sheet_names: - kdfs = pp.read_excel(tmp, sheet_name=sheet_name, index_col=0) - self.assert_eq(kdfs["Sheet_name_1"], pdfs1["Sheet_name_1"]) - self.assert_eq(kdfs["Sheet_name_2"], pdfs1["Sheet_name_2"]) + for sheet_name in sheet_names: + kdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0) + self.assert_eq(kdfs["Sheet_name_1"], pdfs1["Sheet_name_1"]) + self.assert_eq(kdfs["Sheet_name_2"], pdfs1["Sheet_name_2"]) - kdfs = pp.read_excel(tmp, sheet_name=sheet_name, index_col=0, squeeze=True) - self.assert_eq(kdfs["Sheet_name_1"], pdfs1_squeezed["Sheet_name_1"]) - self.assert_eq(kdfs["Sheet_name_2"], pdfs1_squeezed["Sheet_name_2"]) + kdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0, squeeze=True) + self.assert_eq(kdfs["Sheet_name_1"], pdfs1_squeezed["Sheet_name_1"]) + self.assert_eq(kdfs["Sheet_name_2"], pdfs1_squeezed["Sheet_name_2"]) - path2 = "{}/file2.xlsx".format(tmp) - with pd.ExcelWriter(path2) as writer: - self.test_pdf.to_excel(writer, sheet_name="Sheet_name_1") - self.test_pdf[["i32"]].to_excel(writer, sheet_name="Sheet_name_2") + path2 = "{}/file2.xlsx".format(tmp) + with pd.ExcelWriter(path2) as writer: + self.test_pdf.to_excel(writer, sheet_name="Sheet_name_1") + self.test_pdf[["i32"]].to_excel(writer, sheet_name="Sheet_name_2") - pdfs2 = pd.read_excel(path2, sheet_name=None, index_col=0) - pdfs2_squeezed = pd.read_excel(path2, sheet_name=None, index_col=0, squeeze=True) + pdfs2 = pd.read_excel(path2, sheet_name=None, index_col=0) + pdfs2_squeezed = pd.read_excel(path2, sheet_name=None, index_col=0, squeeze=True) + self.assert_eq( + ps.read_excel(tmp, sheet_name="Sheet_name_2", index_col=0).sort_index(), + pd.concat([pdfs1["Sheet_name_2"], pdfs2["Sheet_name_2"]]).sort_index(), + ) + self.assert_eq( + ps.read_excel( + tmp, sheet_name="Sheet_name_2", index_col=0, squeeze=True + ).sort_index(), + pd.concat( + [pdfs1_squeezed["Sheet_name_2"], pdfs2_squeezed["Sheet_name_2"]] + ).sort_index(), + ) + + for sheet_name in sheet_names: + kdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0) self.assert_eq( - pp.read_excel(tmp, sheet_name="Sheet_name_2", index_col=0).sort_index(), + kdfs["Sheet_name_1"].sort_index(), + pd.concat([pdfs1["Sheet_name_1"], pdfs2["Sheet_name_1"]]).sort_index(), + ) + self.assert_eq( + kdfs["Sheet_name_2"].sort_index(), pd.concat([pdfs1["Sheet_name_2"], pdfs2["Sheet_name_2"]]).sort_index(), ) + + kdfs = ps.read_excel(tmp, sheet_name=sheet_name, index_col=0, squeeze=True) self.assert_eq( - pp.read_excel( - tmp, sheet_name="Sheet_name_2", index_col=0, squeeze=True + kdfs["Sheet_name_1"].sort_index(), + pd.concat( + [pdfs1_squeezed["Sheet_name_1"], pdfs2_squeezed["Sheet_name_1"]] ).sort_index(), + ) + self.assert_eq( + kdfs["Sheet_name_2"].sort_index(), pd.concat( [pdfs1_squeezed["Sheet_name_2"], pdfs2_squeezed["Sheet_name_2"]] ).sort_index(), ) - for sheet_name in sheet_names: - kdfs = pp.read_excel(tmp, sheet_name=sheet_name, 
index_col=0) - self.assert_eq( - kdfs["Sheet_name_1"].sort_index(), - pd.concat([pdfs1["Sheet_name_1"], pdfs2["Sheet_name_1"]]).sort_index(), - ) - self.assert_eq( - kdfs["Sheet_name_2"].sort_index(), - pd.concat([pdfs1["Sheet_name_2"], pdfs2["Sheet_name_2"]]).sort_index(), - ) - - kdfs = pp.read_excel(tmp, sheet_name=sheet_name, index_col=0, squeeze=True) - self.assert_eq( - kdfs["Sheet_name_1"].sort_index(), - pd.concat( - [pdfs1_squeezed["Sheet_name_1"], pdfs2_squeezed["Sheet_name_1"]] - ).sort_index(), - ) - self.assert_eq( - kdfs["Sheet_name_2"].sort_index(), - pd.concat( - [pdfs1_squeezed["Sheet_name_2"], pdfs2_squeezed["Sheet_name_2"]] - ).sort_index(), - ) - else: - self.assertRaises(ValueError, lambda: pp.read_excel(tmp)) - def test_read_orc(self): with self.temp_dir() as tmp: path = "{}/file1.orc".format(tmp) @@ -393,50 +382,50 @@ def test_read_orc(self): orc_file_path = glob.glob(os.path.join(path, "*.orc"))[0] expected = data.reset_index()[data.columns] - actual = pp.read_orc(path) + actual = ps.read_orc(path) self.assertPandasEqual(expected, actual.to_pandas()) # columns columns = ["i32", "i64"] expected = data.reset_index()[columns] - actual = pp.read_orc(path, columns=columns) + actual = ps.read_orc(path, columns=columns) self.assertPandasEqual(expected, actual.to_pandas()) # index_col expected = data.set_index("i32") - actual = pp.read_orc(path, index_col="i32") + actual = ps.read_orc(path, index_col="i32") self.assert_eq(actual, expected) expected = data.set_index(["i32", "f"]) - actual = pp.read_orc(path, index_col=["i32", "f"]) + actual = ps.read_orc(path, index_col=["i32", "f"]) self.assert_eq(actual, expected) # index_col with columns expected = data.set_index("i32")[["i64", "bhello"]] - actual = pp.read_orc(path, index_col=["i32"], columns=["i64", "bhello"]) + actual = ps.read_orc(path, index_col=["i32"], columns=["i64", "bhello"]) self.assert_eq(actual, expected) expected = data.set_index(["i32", "f"])[["bhello", "i64"]] - actual = pp.read_orc(path, index_col=["i32", "f"], columns=["bhello", "i64"]) + actual = ps.read_orc(path, index_col=["i32", "f"], columns=["bhello", "i64"]) self.assert_eq(actual, expected) msg = "Unknown column name 'i'" with self.assertRaises(ValueError, msg=msg): - pp.read_orc(path, columns="i32") + ps.read_orc(path, columns="i32") msg = "Unknown column name 'i34'" with self.assertRaises(ValueError, msg=msg): - pp.read_orc(path, columns=["i34", "i64"]) + ps.read_orc(path, columns=["i34", "i64"]) def test_orc_write(self): with self.temp_dir() as tmp: pdf = self.test_pdf - expected = pp.DataFrame(pdf) + expected = ps.DataFrame(pdf) # Write out partitioned by one column expected.to_orc(tmp, mode="overwrite", partition_cols="i32") # Reset column order, as once the data is written out, Spark rearranges partition # columns to appear first. - actual = pp.read_orc(tmp) + actual = ps.read_orc(tmp) self.assertFalse((actual.columns == self.test_column_order).all()) actual = actual[self.test_column_order] self.assert_eq( @@ -448,7 +437,7 @@ def test_orc_write(self): expected.to_orc(tmp, mode="overwrite", partition_cols=["i32", "bhello"]) # Reset column order, as once the data is written out, Spark rearranges partition # columns to appear first. 
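The directory-based `read_excel` assertions above boil down to the following standalone sketch (assuming an Excel writer engine such as `openpyxl` is installed; the paths are illustrative):

```python
import os
import tempfile

import pandas as pd
import pyspark.pandas as ps

tmp = tempfile.mkdtemp()
pd.DataFrame({"i32": [1, 2]}).to_excel(os.path.join(tmp, "file1.xlsx"))
pd.DataFrame({"i32": [3, 4]}).to_excel(os.path.join(tmp, "file2.xlsx"))

# Pointing read_excel at the directory unions both workbooks; index_col=0 restores
# the index column written by to_excel, as in the assertions above.
kdf = ps.read_excel(tmp, index_col=0)
print(kdf.sort_index().to_pandas())
```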
- actual = pp.read_orc(tmp) + actual = ps.read_orc(tmp) self.assertFalse((actual.columns == self.test_column_order).all()) actual = actual[self.test_column_order] self.assert_eq( diff --git a/python/pyspark/pandas/tests/test_default_index.py b/python/pyspark/pandas/tests/test_default_index.py index 4075de4f11..838e04a9eb 100644 --- a/python/pyspark/pandas/tests/test_default_index.py +++ b/python/pyspark/pandas/tests/test_default_index.py @@ -18,10 +18,10 @@ import pandas as pd from pyspark import pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestCase -class DefaultIndexTest(ReusedSQLTestCase): +class DefaultIndexTest(PandasOnSparkTestCase): def test_default_index_sequence(self): with ps.option_context("compute.default_index_type", "sequence"): sdf = self.spark.range(1000) diff --git a/python/pyspark/pandas/tests/test_expanding.py b/python/pyspark/pandas/tests/test_expanding.py index c341892c51..7198a1d5d0 100644 --- a/python/pyspark/pandas/tests/test_expanding.py +++ b/python/pyspark/pandas/tests/test_expanding.py @@ -21,11 +21,11 @@ import pandas as pd import pyspark.pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils from pyspark.pandas.window import Expanding +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class ExpandingTest(ReusedSQLTestCase, TestUtils): +class ExpandingTest(PandasOnSparkTestCase, TestUtils): def _test_expanding_func(self, f): pser = pd.Series([1, 2, 3], index=np.random.rand(3)) kser = ps.from_pandas(pser) diff --git a/python/pyspark/pandas/tests/test_extension.py b/python/pyspark/pandas/tests/test_extension.py index 9fb61d02cc..17dc2bcd8b 100644 --- a/python/pyspark/pandas/tests/test_extension.py +++ b/python/pyspark/pandas/tests/test_extension.py @@ -21,7 +21,7 @@ import pandas as pd from pyspark import pandas as ps -from pyspark.pandas.testing.utils import assert_produces_warning, ReusedSQLTestCase +from pyspark.testing.pandasutils import assert_produces_warning, PandasOnSparkTestCase from pyspark.pandas.extensions import ( register_dataframe_accessor, register_series_accessor, @@ -66,7 +66,7 @@ def check_length(self, col=None): raise ValueError(str(e)) -class ExtensionTest(ReusedSQLTestCase): +class ExtensionTest(PandasOnSparkTestCase): @property def pdf(self): return pd.DataFrame( diff --git a/python/pyspark/pandas/tests/test_frame_spark.py b/python/pyspark/pandas/tests/test_frame_spark.py index 3dca25f6ab..6a226a740f 100644 --- a/python/pyspark/pandas/tests/test_frame_spark.py +++ b/python/pyspark/pandas/tests/test_frame_spark.py @@ -15,22 +15,21 @@ # limitations under the License. 
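As with the other IO helpers above, the ORC read/write combination being exercised can be sketched standalone (the column names and temporary path are illustrative):

```python
import tempfile

import pyspark.pandas as ps

path = tempfile.mkdtemp()
kdf = ps.DataFrame({"i32": [1, 2, 3], "i64": [10, 20, 30], "f": [0.1, 0.2, 0.3]})

# Write ORC partitioned by one column; the partition column comes back in a different
# position, which is why the test re-selects the original column order after reading.
kdf.to_orc(path, mode="overwrite", partition_cols="i32")

# index_col promotes a column to the index, columns restricts what is read back.
actual = ps.read_orc(path, index_col="i32", columns=["i64", "f"])
print(actual.sort_index().to_pandas())
```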
# -from distutils.version import LooseVersion import os import pandas as pd -import pyspark -from pyspark import pandas as pp -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils, TestUtils +from pyspark import pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils +from pyspark.testing.sqlutils import SQLTestUtils -class SparkFrameMethodsTest(ReusedSQLTestCase, SQLTestUtils, TestUtils): +class SparkFrameMethodsTest(PandasOnSparkTestCase, SQLTestUtils, TestUtils): def test_frame_apply_negative(self): with self.assertRaisesRegex( ValueError, "The output of the function.* pyspark.sql.DataFrame.*int" ): - pp.range(10).spark.apply(lambda scol: 1) + ps.range(10).spark.apply(lambda scol: 1) def test_hint(self): pdf1 = pd.DataFrame( @@ -39,13 +38,10 @@ def test_hint(self): pdf2 = pd.DataFrame( {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]} ).set_index("rkey") - kdf1 = pp.from_pandas(pdf1) - kdf2 = pp.from_pandas(pdf2) + kdf1 = ps.from_pandas(pdf1) + kdf2 = ps.from_pandas(pdf2) - if LooseVersion(pyspark.__version__) >= LooseVersion("3.0"): - hints = ["broadcast", "merge", "shuffle_hash", "shuffle_replicate_nl"] - else: - hints = ["broadcast"] + hints = ["broadcast", "merge", "shuffle_hash", "shuffle_replicate_nl"] for hint in hints: self.assert_eq( @@ -68,7 +64,7 @@ def test_hint(self): ) def test_repartition(self): - kdf = pp.DataFrame({"age": [5, 5, 2, 2], "name": ["Bob", "Bob", "Alice", "Alice"]}) + kdf = ps.DataFrame({"age": [5, 5, 2, 2], "name": ["Bob", "Bob", "Alice", "Alice"]}) num_partitions = kdf.to_spark().rdd.getNumPartitions() + 1 num_partitions += 1 @@ -91,7 +87,7 @@ def test_repartition(self): self.assert_eq(kdf2.sort_index(), (kdf + 1).spark.repartition(num_partitions).sort_index()) # Reserves MultiIndex - kdf = pp.DataFrame({"a": ["a", "b", "c"]}, index=[[1, 2, 3], [4, 5, 6]]) + kdf = ps.DataFrame({"a": ["a", "b", "c"]}, index=[[1, 2, 3], [4, 5, 6]]) num_partitions = kdf.to_spark().rdd.getNumPartitions() + 1 new_kdf = kdf.spark.repartition(num_partitions) self.assertEqual(new_kdf.to_spark().rdd.getNumPartitions(), num_partitions) @@ -99,7 +95,7 @@ def test_repartition(self): def test_coalesce(self): num_partitions = 10 - kdf = pp.DataFrame({"age": [5, 5, 2, 2], "name": ["Bob", "Bob", "Alice", "Alice"]}) + kdf = ps.DataFrame({"age": [5, 5, 2, 2], "name": ["Bob", "Bob", "Alice", "Alice"]}) kdf = kdf.spark.repartition(num_partitions) num_partitions -= 1 @@ -122,7 +118,7 @@ def test_coalesce(self): self.assert_eq(kdf2.sort_index(), (kdf + 1).spark.coalesce(num_partitions).sort_index()) # Reserves MultiIndex - kdf = pp.DataFrame({"a": ["a", "b", "c"]}, index=[[1, 2, 3], [4, 5, 6]]) + kdf = ps.DataFrame({"a": ["a", "b", "c"]}, index=[[1, 2, 3], [4, 5, 6]]) num_partitions -= 1 kdf = kdf.spark.repartition(num_partitions) @@ -134,13 +130,13 @@ def test_coalesce(self): def test_checkpoint(self): with self.temp_dir() as tmp: self.spark.sparkContext.setCheckpointDir(tmp) - kdf = pp.DataFrame({"a": ["a", "b", "c"]}) + kdf = ps.DataFrame({"a": ["a", "b", "c"]}) new_kdf = kdf.spark.checkpoint() self.assertIsNotNone(os.listdir(tmp)) self.assert_eq(kdf, new_kdf) def test_local_checkpoint(self): - kdf = pp.DataFrame({"a": ["a", "b", "c"]}) + kdf = ps.DataFrame({"a": ["a", "b", "c"]}) new_kdf = kdf.spark.local_checkpoint() self.assert_eq(kdf, new_kdf) diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py index ca1f9e9f72..a6d006fad9 100644 --- 
a/python/pyspark/pandas/tests/test_groupby.py +++ b/python/pyspark/pandas/tests/test_groupby.py @@ -30,11 +30,11 @@ MissingPandasLikeDataFrameGroupBy, MissingPandasLikeSeriesGroupBy, ) -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils from pyspark.pandas.groupby import is_multi_agg_with_relabel +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class GroupByTest(ReusedSQLTestCase, TestUtils): +class GroupByTest(PandasOnSparkTestCase, TestUtils): def test_groupby_simple(self): pdf = pd.DataFrame( { diff --git a/python/pyspark/pandas/tests/test_indexing.py b/python/pyspark/pandas/tests/test_indexing.py index 8298767f67..0d02c46b0e 100644 --- a/python/pyspark/pandas/tests/test_indexing.py +++ b/python/pyspark/pandas/tests/test_indexing.py @@ -24,7 +24,7 @@ from pyspark import pandas as ps from pyspark.pandas.exceptions import SparkPandasIndexingError -from pyspark.pandas.testing.utils import ComparisonTestBase, ReusedSQLTestCase, compare_both +from pyspark.testing.pandasutils import ComparisonTestBase, PandasOnSparkTestCase, compare_both class BasicIndexingTest(ComparisonTestBase): @@ -153,7 +153,7 @@ def test_limitations(self): ) -class IndexingTest(ReusedSQLTestCase): +class IndexingTest(PandasOnSparkTestCase): @property def pdf(self): return pd.DataFrame( diff --git a/python/pyspark/pandas/tests/test_indexops_spark.py b/python/pyspark/pandas/tests/test_indexops_spark.py index ae659ac17f..831b764271 100644 --- a/python/pyspark/pandas/tests/test_indexops_spark.py +++ b/python/pyspark/pandas/tests/test_indexops_spark.py @@ -20,10 +20,11 @@ from pyspark.sql import functions as F from pyspark import pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils -class SparkIndexOpsMethodsTest(ReusedSQLTestCase, SQLTestUtils): +class SparkIndexOpsMethodsTest(PandasOnSparkTestCase, SQLTestUtils): @property def pser(self): return pd.Series([1, 2, 3, 4, 5, 6, 7], name="x") diff --git a/python/pyspark/pandas/tests/test_internal.py b/python/pyspark/pandas/tests/test_internal.py index f93b24bbe1..f9e96cd995 100644 --- a/python/pyspark/pandas/tests/test_internal.py +++ b/python/pyspark/pandas/tests/test_internal.py @@ -22,10 +22,11 @@ SPARK_DEFAULT_INDEX_NAME, SPARK_INDEX_NAME_FORMAT, ) -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils -class InternalFrameTest(ReusedSQLTestCase, SQLTestUtils): +class InternalFrameTest(PandasOnSparkTestCase, SQLTestUtils): def test_from_pandas(self): pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) diff --git a/python/pyspark/pandas/tests/test_namespace.py b/python/pyspark/pandas/tests/test_namespace.py index 9172f045c2..e8787397e1 100644 --- a/python/pyspark/pandas/tests/test_namespace.py +++ b/python/pyspark/pandas/tests/test_namespace.py @@ -20,11 +20,12 @@ import pandas as pd from pyspark import pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils from pyspark.pandas.namespace import _get_index_map +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils -class NamespaceTest(ReusedSQLTestCase, SQLTestUtils): +class NamespaceTest(PandasOnSparkTestCase, SQLTestUtils): def test_from_pandas(self): pdf = pd.DataFrame({"year": [2015, 2016], "month": [2, 3], "day": 
[4, 5]}) kdf = ps.from_pandas(pdf) diff --git a/python/pyspark/pandas/tests/test_numpy_compat.py b/python/pyspark/pandas/tests/test_numpy_compat.py index e278739c31..ce2bbe1702 100644 --- a/python/pyspark/pandas/tests/test_numpy_compat.py +++ b/python/pyspark/pandas/tests/test_numpy_compat.py @@ -23,10 +23,11 @@ from pyspark import pandas as ps from pyspark.pandas import set_option, reset_option from pyspark.pandas.numpy_compat import unary_np_spark_mappings, binary_np_spark_mappings -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils -class NumPyCompatTest(ReusedSQLTestCase, SQLTestUtils): +class NumPyCompatTest(PandasOnSparkTestCase, SQLTestUtils): blacklist = [ # Koalas does not currently support "conj", diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py index d567bae3cd..a998414542 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py @@ -22,12 +22,11 @@ import pandas as pd import numpy as np -import pyspark - from pyspark import pandas as ps from pyspark.pandas.config import set_option, reset_option from pyspark.pandas.frame import DataFrame -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils from pyspark.pandas.typedef.typehints import ( extension_dtypes, extension_dtypes_available, @@ -36,7 +35,7 @@ ) -class OpsOnDiffFramesEnabledTest(ReusedSQLTestCase, SQLTestUtils): +class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils): @classmethod def setUpClass(cls): super().setUpClass() @@ -1549,10 +1548,7 @@ def test_series_repeat(self): kser1 = ps.from_pandas(pser1) kser2 = ps.from_pandas(pser2) - if LooseVersion(pyspark.__version__) < LooseVersion("2.4"): - self.assertRaises(ValueError, lambda: kser1.repeat(kser2)) - else: - self.assert_eq(kser1.repeat(kser2).sort_index(), pser1.repeat(pser2).sort_index()) + self.assert_eq(kser1.repeat(kser2).sort_index(), pser1.repeat(pser2).sort_index()) def test_series_ops(self): pser1 = pd.Series([1, 2, 3, 4, 5, 6, 7], name="x", index=[11, 12, 13, 14, 15, 16, 17]) @@ -1774,7 +1770,7 @@ def test_rank(self): ) -class OpsOnDiffFramesDisabledTest(ReusedSQLTestCase, SQLTestUtils): +class OpsOnDiffFramesDisabledTest(PandasOnSparkTestCase, SQLTestUtils): @classmethod def setUpClass(cls): super().setUpClass() diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py index 84c72cbbbb..ce4653868d 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py @@ -21,10 +21,11 @@ from pyspark import pandas as ps from pyspark.pandas.config import set_option, reset_option -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils -class OpsOnDiffFramesGroupByTest(ReusedSQLTestCase, SQLTestUtils): +class OpsOnDiffFramesGroupByTest(PandasOnSparkTestCase, SQLTestUtils): @classmethod def setUpClass(cls): super().setUpClass() diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py 
b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py index 88cf84e95d..afd81854e8 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py @@ -22,10 +22,10 @@ from pyspark import pandas as ps from pyspark.pandas.config import set_option, reset_option -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class OpsOnDiffFramesGroupByExpandingTest(ReusedSQLTestCase, TestUtils): +class OpsOnDiffFramesGroupByExpandingTest(PandasOnSparkTestCase, TestUtils): @classmethod def setUpClass(cls): super().setUpClass() diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py index 8b7e3edb43..158af35f61 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py @@ -19,10 +19,10 @@ from pyspark import pandas as ps from pyspark.pandas.config import set_option, reset_option -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class OpsOnDiffFramesGroupByRollingTest(ReusedSQLTestCase, TestUtils): +class OpsOnDiffFramesGroupByRollingTest(PandasOnSparkTestCase, TestUtils): @classmethod def setUpClass(cls): super().setUpClass() diff --git a/python/pyspark/pandas/tests/test_repr.py b/python/pyspark/pandas/tests/test_repr.py index 7259639ae9..e2b1c166f1 100644 --- a/python/pyspark/pandas/tests/test_repr.py +++ b/python/pyspark/pandas/tests/test_repr.py @@ -15,17 +15,14 @@ # limitations under the License. 
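The mechanical part of this migration is the same in every test module: the base class and helpers now come from `pyspark.testing` instead of `pyspark.pandas.testing.utils`. A minimal post-migration test module might look like the sketch below (class and test names are made up for illustration):

```python
import unittest

import pandas as pd
import pyspark.pandas as ps
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils


class RoundTripTest(PandasOnSparkTestCase, TestUtils):
    def test_round_trip(self):
        pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
        kdf = ps.from_pandas(pdf)
        # assert_eq converts pandas-on-Spark objects to pandas before comparing.
        self.assert_eq(kdf.sort_index(), pdf.sort_index())

    def test_write_to_temp_dir(self):
        with self.temp_dir() as tmp:
            # temp_dir() from TestUtils removes the directory when the block exits.
            ps.DataFrame({"a": [1, 2, 3]}).to_orc(tmp, mode="overwrite")


if __name__ == "__main__":
    unittest.main()
```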
# -from distutils.version import LooseVersion - import numpy as np -import pyspark from pyspark import pandas as ps from pyspark.pandas.config import set_option, reset_option, option_context -from pyspark.pandas.testing.utils import ReusedSQLTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestCase -class ReprTest(ReusedSQLTestCase): +class ReprTest(PandasOnSparkTestCase): max_display_count = 23 @classmethod @@ -82,26 +79,25 @@ def test_repr_series(self): kser = ps.range(ReprTest.max_display_count + 1).id.rename() self.assert_eq(repr(kser), repr(kser.to_pandas())) - if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): - kser = ps.MultiIndex.from_tuples( - [(100 * i, i) for i in range(ReprTest.max_display_count)] - ).to_series() - self.assertTrue("Showing only the first" not in repr(kser)) - self.assert_eq(repr(kser), repr(kser.to_pandas())) + kser = ps.MultiIndex.from_tuples( + [(100 * i, i) for i in range(ReprTest.max_display_count)] + ).to_series() + self.assertTrue("Showing only the first" not in repr(kser)) + self.assert_eq(repr(kser), repr(kser.to_pandas())) + + kser = ps.MultiIndex.from_tuples( + [(100 * i, i) for i in range(ReprTest.max_display_count + 1)] + ).to_series() + self.assertTrue("Showing only the first" in repr(kser)) + self.assertTrue( + repr(kser).startswith(repr(kser.to_pandas().head(ReprTest.max_display_count))) + ) + with option_context("display.max_rows", None): kser = ps.MultiIndex.from_tuples( [(100 * i, i) for i in range(ReprTest.max_display_count + 1)] ).to_series() - self.assertTrue("Showing only the first" in repr(kser)) - self.assertTrue( - repr(kser).startswith(repr(kser.to_pandas().head(ReprTest.max_display_count))) - ) - - with option_context("display.max_rows", None): - kser = ps.MultiIndex.from_tuples( - [(100 * i, i) for i in range(ReprTest.max_display_count + 1)] - ).to_series() - self.assert_eq(repr(kser), repr(kser.to_pandas())) + self.assert_eq(repr(kser), repr(kser.to_pandas())) def test_repr_indexes(self): kidx = ps.range(ReprTest.max_display_count).index diff --git a/python/pyspark/pandas/tests/test_reshape.py b/python/pyspark/pandas/tests/test_reshape.py index 1f3dfbe2d7..96665dfa01 100644 --- a/python/pyspark/pandas/tests/test_reshape.py +++ b/python/pyspark/pandas/tests/test_reshape.py @@ -21,14 +21,13 @@ import numpy as np import pandas as pd -import pyspark from pyspark import pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SPARK_CONF_ARROW_ENABLED from pyspark.pandas.utils import name_like_string +from pyspark.testing.pandasutils import PandasOnSparkTestCase -class ReshapeTest(ReusedSQLTestCase): +class ReshapeTest(PandasOnSparkTestCase): def test_get_dummies(self): for pdf_or_ps in [ pd.Series([1, 1, 1, 2, 2, 1, 3, 4]), @@ -111,41 +110,23 @@ def test_get_dummies_date_datetime(self): ) kdf = ps.from_pandas(pdf) - if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): - self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8)) - self.assert_eq(ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8)) - self.assert_eq(ps.get_dummies(kdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8)) - else: - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8)) - self.assert_eq(ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8)) - self.assert_eq(ps.get_dummies(kdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8)) + self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8)) + 
self.assert_eq(ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8)) + self.assert_eq(ps.get_dummies(kdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8)) def test_get_dummies_boolean(self): pdf = pd.DataFrame({"b": [True, False, True]}) kdf = ps.from_pandas(pdf) - if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): - self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8)) - self.assert_eq(ps.get_dummies(kdf.b), pd.get_dummies(pdf.b, dtype=np.int8)) - else: - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8)) - self.assert_eq(ps.get_dummies(kdf.b), pd.get_dummies(pdf.b, dtype=np.int8)) + self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8)) + self.assert_eq(ps.get_dummies(kdf.b), pd.get_dummies(pdf.b, dtype=np.int8)) def test_get_dummies_decimal(self): pdf = pd.DataFrame({"d": [Decimal(1.0), Decimal(2.0), Decimal(1)]}) kdf = ps.from_pandas(pdf) - if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): - self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8)) - self.assert_eq(ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True) - else: - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8)) - self.assert_eq( - ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True - ) + self.assert_eq(ps.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8)) + self.assert_eq(ps.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True) def test_get_dummies_kwargs(self): # pser = pd.Series([1, 1, 1, 2, 2, 1, 3, 4], dtype='category') diff --git a/python/pyspark/pandas/tests/test_rolling.py b/python/pyspark/pandas/tests/test_rolling.py index f664b2ac9f..3827a6017e 100644 --- a/python/pyspark/pandas/tests/test_rolling.py +++ b/python/pyspark/pandas/tests/test_rolling.py @@ -19,11 +19,11 @@ import pandas as pd import pyspark.pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils from pyspark.pandas.window import Rolling -class RollingTest(ReusedSQLTestCase, TestUtils): +class RollingTest(PandasOnSparkTestCase, TestUtils): def test_rolling_error(self): with self.assertRaisesRegex(ValueError, "window must be >= 0"): ps.range(10).rolling(window=-1) diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index b5960d65ab..eae26bc4c8 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -24,17 +24,17 @@ import numpy as np import pandas as pd -import pyspark from pyspark.ml.linalg import SparseVector from pyspark.sql import functions as F from pyspark import pandas as ps -from pyspark.pandas.testing.utils import ( +from pyspark.testing.pandasutils import ( have_tabulate, - ReusedSQLTestCase, - SQLTestUtils, + PandasOnSparkTestCase, SPARK_CONF_ARROW_ENABLED, + tabulate_requirement_message, ) +from pyspark.testing.sqlutils import SQLTestUtils from pyspark.pandas.exceptions import PandasNotImplementedError from pyspark.pandas.missing.series import MissingPandasLikeSeries from pyspark.pandas.typedef.typehints import ( @@ -45,7 +45,7 @@ ) -class SeriesTest(ReusedSQLTestCase, SQLTestUtils): +class SeriesTest(PandasOnSparkTestCase, SQLTestUtils): @property def pser(self): return pd.Series([1, 2, 3, 4, 5, 6, 7], name="x") @@ -146,11 +146,7 @@ def 
test_empty_series(self): self.assert_eq(ps.from_pandas(pser_a), pser_a) kser_b = ps.from_pandas(pser_b) - if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): - self.assert_eq(kser_b, pser_b) - else: - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - self.assert_eq(kser_b, pser_b) + self.assert_eq(kser_b, pser_b) with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): self.assert_eq(ps.from_pandas(pser_a), pser_a) @@ -163,11 +159,7 @@ def test_all_null_series(self): self.assert_eq(ps.from_pandas(pser_a), pser_a) kser_b = ps.from_pandas(pser_b) - if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"): - self.assert_eq(kser_b, pser_b) - else: - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - self.assert_eq(kser_b, pser_b) + self.assert_eq(kser_b, pser_b) with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): self.assert_eq(ps.from_pandas(pser_a), pser_a) @@ -628,7 +620,7 @@ def test_nunique(self): self.assertEqual(ps.Series(range(100)).nunique(approx=True), 103) self.assertEqual(ps.Series(range(100)).nunique(approx=True, rsd=0.01), 100) - def _test_value_counts(self): + def test_value_counts(self): # this is also containing test for Index & MultiIndex pser = pd.Series( [1, 2, 1, 3, 3, np.nan, 1, 4, 2, np.nan, 3, np.nan, 3, 1, 3], @@ -856,17 +848,6 @@ def _test_value_counts(self): almost=True, ) - def test_value_counts(self): - if LooseVersion(pyspark.__version__) < LooseVersion("2.4"): - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - self._test_value_counts() - self.assertRaises( - RuntimeError, - lambda: ps.MultiIndex.from_tuples([("x", "a"), ("x", "b")]).value_counts(), - ) - else: - self._test_value_counts() - def test_nsmallest(self): sample_lst = [1, 2, 3, 4, np.nan, 6] pser = pd.Series(sample_lst, name="x") @@ -1891,14 +1872,8 @@ def test_udt(self): sparse_values = {0: 0.1, 1: 1.1} sparse_vector = SparseVector(len(sparse_values), sparse_values) pser = pd.Series([sparse_vector]) - - if LooseVersion(pyspark.__version__) < LooseVersion("2.4"): - with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): - kser = ps.from_pandas(pser) - self.assert_eq(kser, pser) - else: - kser = ps.from_pandas(pser) - self.assert_eq(kser, pser) + kser = ps.from_pandas(pser) + self.assert_eq(kser, pser) def test_repeat(self): pser = pd.Series(["a", "b", "c"], name="0", index=np.random.rand(3)) @@ -1913,10 +1888,7 @@ def test_repeat(self): pdf = pd.DataFrame({"a": ["a", "b", "c"], "rep": [10, 20, 30]}, index=np.random.rand(3)) kdf = ps.from_pandas(pdf) - if LooseVersion(pyspark.__version__) < LooseVersion("2.4"): - self.assertRaises(ValueError, lambda: kdf.a.repeat(kdf.rep)) - else: - self.assert_eq(kdf.a.repeat(kdf.rep).sort_index(), pdf.a.repeat(pdf.rep).sort_index()) + self.assert_eq(kdf.a.repeat(kdf.rep).sort_index(), pdf.a.repeat(pdf.rep).sort_index()) def test_take(self): pser = pd.Series([100, 200, 300, 400, 500], name="Koalas") @@ -2209,7 +2181,7 @@ def test_shape(self): self.assert_eq(pser.shape, kser.shape) - @unittest.skipIf(not have_tabulate, "tabulate not installed") + @unittest.skipIf(not have_tabulate, tabulate_requirement_message) def test_to_markdown(self): pser = pd.Series(["elk", "pig", "dog", "quetzal"], name="animal") kser = ps.from_pandas(pser) @@ -2407,10 +2379,6 @@ def test_dot(self): self.assert_eq((kdf["b"] * 10).dot(kdf), (pdf["b"] * 10).dot(pdf)) self.assert_eq((kdf["b"] * 10).dot(kdf + 1), (pdf["b"] * 10).dot(pdf + 1)) - @unittest.skipIf( - LooseVersion(pyspark.__version__) < LooseVersion("3.0"), - "tail won't work properly with PySpark<3.0", - 
) def test_tail(self): pser = pd.Series(range(1000), name="Koalas") kser = ps.from_pandas(pser) @@ -2508,10 +2476,6 @@ def test_hasnans(self): kser = ps.from_pandas(pser) self.assert_eq(pser.hasnans, kser.hasnans) - @unittest.skipIf( - LooseVersion(pyspark.__version__) < LooseVersion("3.0"), - "last_valid_index won't work properly with PySpark<3.0", - ) def test_last_valid_index(self): pser = pd.Series([250, 1.5, 320, 1, 0.3, None, None, None, None]) kser = ps.from_pandas(pser) diff --git a/python/pyspark/pandas/tests/test_series_conversion.py b/python/pyspark/pandas/tests/test_series_conversion.py index 2b19249c0d..18ce24de74 100644 --- a/python/pyspark/pandas/tests/test_series_conversion.py +++ b/python/pyspark/pandas/tests/test_series_conversion.py @@ -21,10 +21,11 @@ import pandas as pd from pyspark import pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils -class SeriesConversionTest(ReusedSQLTestCase, SQLTestUtils): +class SeriesConversionTest(PandasOnSparkTestCase, SQLTestUtils): @property def pser(self): return pd.Series([1, 2, 3, 4, 5, 6, 7], name="x") diff --git a/python/pyspark/pandas/tests/test_series_datetime.py b/python/pyspark/pandas/tests/test_series_datetime.py index fc27c96edf..deb4497483 100644 --- a/python/pyspark/pandas/tests/test_series_datetime.py +++ b/python/pyspark/pandas/tests/test_series_datetime.py @@ -22,10 +22,11 @@ import pandas as pd from pyspark import pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils -class SeriesDateTimeTest(ReusedSQLTestCase, SQLTestUtils): +class SeriesDateTimeTest(PandasOnSparkTestCase, SQLTestUtils): @property def pdf1(self): date1 = pd.Series(pd.date_range("2012-1-1 12:45:31", periods=3, freq="M")) diff --git a/python/pyspark/pandas/tests/test_series_string.py b/python/pyspark/pandas/tests/test_series_string.py index 053c4d79be..69a9ab3424 100644 --- a/python/pyspark/pandas/tests/test_series_string.py +++ b/python/pyspark/pandas/tests/test_series_string.py @@ -20,10 +20,11 @@ import re from pyspark import pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils -class SeriesStringTest(ReusedSQLTestCase, SQLTestUtils): +class SeriesStringTest(PandasOnSparkTestCase, SQLTestUtils): @property def pser(self): return pd.Series( diff --git a/python/pyspark/pandas/tests/test_sql.py b/python/pyspark/pandas/tests/test_sql.py index 6d29beee97..6c3405f0f0 100644 --- a/python/pyspark/pandas/tests/test_sql.py +++ b/python/pyspark/pandas/tests/test_sql.py @@ -16,12 +16,12 @@ # from pyspark import pandas as ps -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils - from pyspark.sql.utils import ParseException +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils -class SQLTest(ReusedSQLTestCase, SQLTestUtils): +class SQLTest(PandasOnSparkTestCase, SQLTestUtils): def test_error_variable_not_exist(self): msg = "The key variable_foo in the SQL statement was not found.*" with self.assertRaisesRegex(ValueError, msg): diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py index 
1a1885c845..905add4baa 100644 --- a/python/pyspark/pandas/tests/test_stats.py +++ b/python/pyspark/pandas/tests/test_stats.py @@ -27,14 +27,11 @@ from pyspark import pandas as ps from pyspark.pandas.config import option_context -from pyspark.pandas.testing.utils import ( - ReusedSQLTestCase, - SQLTestUtils, - SPARK_CONF_ARROW_ENABLED, -) +from pyspark.testing.pandasutils import PandasOnSparkTestCase, SPARK_CONF_ARROW_ENABLED +from pyspark.testing.sqlutils import SQLTestUtils -class StatsTest(ReusedSQLTestCase, SQLTestUtils): +class StatsTest(PandasOnSparkTestCase, SQLTestUtils): def _test_stat_functions(self, pdf_or_pser, kdf_or_kser): functions = ["max", "min", "mean", "sum", "count"] for funcname in functions: diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py index 8ab1e0324b..2f4039ba20 100644 --- a/python/pyspark/pandas/tests/test_utils.py +++ b/python/pyspark/pandas/tests/test_utils.py @@ -17,17 +17,18 @@ import pandas as pd -from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils from pyspark.pandas.utils import ( lazy_property, validate_arguments_and_invoke_function, validate_bool_kwarg, ) +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils some_global_variable = 0 -class UtilsTest(ReusedSQLTestCase, SQLTestUtils): +class UtilsTest(PandasOnSparkTestCase, SQLTestUtils): # a dummy to_html version with an extra parameter that pandas does not support # used in test_validate_arguments_and_invoke_function diff --git a/python/pyspark/pandas/tests/test_window.py b/python/pyspark/pandas/tests/test_window.py index 742b3b9cbd..8c347b8687 100644 --- a/python/pyspark/pandas/tests/test_window.py +++ b/python/pyspark/pandas/tests/test_window.py @@ -25,10 +25,10 @@ MissingPandasLikeExpandingGroupby, MissingPandasLikeRollingGroupby, ) -from pyspark.pandas.testing.utils import ReusedSQLTestCase, TestUtils +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -class ExpandingRollingTest(ReusedSQLTestCase, TestUtils): +class ExpandingRollingTest(PandasOnSparkTestCase, TestUtils): def test_missing(self): kdf = ps.DataFrame({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9]}) diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py new file mode 100644 index 0000000000..4a5bfe8d56 --- /dev/null +++ b/python/pyspark/testing/pandasutils.py @@ -0,0 +1,373 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import functools +import shutil +import tempfile +import unittest +import warnings +from contextlib import contextmanager +from distutils.version import LooseVersion + +import pandas as pd +from pandas.api.types import is_list_like +from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal + +from pyspark import pandas as ps +from pyspark.pandas.frame import DataFrame +from pyspark.pandas.indexes import Index +from pyspark.pandas.series import Series +from pyspark.pandas.utils import default_session, SPARK_CONF_ARROW_ENABLED +from pyspark.testing.sqlutils import SQLTestUtils + + +tabulate_requirement_message = None +try: + from tabulate import tabulate # noqa: F401 +except ImportError as e: + # If tabulate requirement is not satisfied, skip related tests. + tabulate_requirement_message = str(e) +have_tabulate = tabulate_requirement_message is None + +matplotlib_requirement_message = None +try: + import matplotlib # type: ignore # noqa: F401 +except ImportError as e: + # If matplotlib requirement is not satisfied, skip related tests. + matplotlib_requirement_message = str(e) +have_matplotlib = matplotlib_requirement_message is None + +plotly_requirement_message = None +try: + import plotly # type: ignore # noqa: F401 +except ImportError as e: + # If plotly requirement is not satisfied, skip related tests. + plotly_requirement_message = str(e) +have_plotly = plotly_requirement_message is None + + +class PandasOnSparkTestCase(unittest.TestCase, SQLTestUtils): + @classmethod + def setUpClass(cls): + cls.spark = default_session() + cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True) + + @classmethod + def tearDownClass(cls): + # We don't stop Spark session to reuse across all tests. + # The Spark session will be started and stopped at PyTest session level. + # Please see databricks/koalas/conftest.py. 
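These `have_*` flags and `*_requirement_message` strings are what the migrated tests pass to `unittest.skipIf`, so a missing optional dependency skips the test with the original `ImportError` text. A small sketch of the intended usage (test names are illustrative):

```python
import unittest

import pandas as pd
import pyspark.pandas as ps
from pyspark.testing.pandasutils import (
    PandasOnSparkTestCase,
    have_tabulate,
    tabulate_requirement_message,
)


class MarkdownTest(PandasOnSparkTestCase):
    # Skipped with the captured ImportError message when tabulate is not installed.
    @unittest.skipIf(not have_tabulate, tabulate_requirement_message)
    def test_to_markdown(self):
        pser = pd.Series(["elk", "pig", "dog"], name="animal")
        self.assert_eq(pser.to_markdown(), ps.from_pandas(pser).to_markdown())
```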
+ pass + + def assertPandasEqual(self, left, right, check_exact=True): + if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): + try: + if LooseVersion(pd.__version__) >= LooseVersion("1.1"): + kwargs = dict(check_freq=False) + else: + kwargs = dict() + + assert_frame_equal( + left, + right, + check_index_type=("equiv" if len(left.index) > 0 else False), + check_column_type=("equiv" if len(left.columns) > 0 else False), + check_exact=check_exact, + **kwargs + ) + except AssertionError as e: + msg = ( + str(e) + + "\n\nLeft:\n%s\n%s" % (left, left.dtypes) + + "\n\nRight:\n%s\n%s" % (right, right.dtypes) + ) + raise AssertionError(msg) from e + elif isinstance(left, pd.Series) and isinstance(right, pd.Series): + try: + if LooseVersion(pd.__version__) >= LooseVersion("1.1"): + kwargs = dict(check_freq=False) + else: + kwargs = dict() + + assert_series_equal( + left, + right, + check_index_type=("equiv" if len(left.index) > 0 else False), + check_exact=check_exact, + **kwargs + ) + except AssertionError as e: + msg = ( + str(e) + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + raise AssertionError(msg) from e + elif isinstance(left, pd.Index) and isinstance(right, pd.Index): + try: + assert_index_equal(left, right, check_exact=check_exact) + except AssertionError as e: + msg = ( + str(e) + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + raise AssertionError(msg) from e + else: + raise ValueError("Unexpected values: (%s, %s)" % (left, right)) + + def assertPandasAlmostEqual(self, left, right): + """ + This function checks if given pandas objects approximately same, + which means the conditions below: + - Both objects are nullable + - Compare floats rounding to the number of decimal places, 7 after + dropping missing values (NaN, NaT, None) + """ + if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): + msg = ( + "DataFrames are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtypes) + + "\n\nRight:\n%s\n%s" % (right, right.dtypes) + ) + self.assertEqual(left.shape, right.shape, msg=msg) + for lcol, rcol in zip(left.columns, right.columns): + self.assertEqual(lcol, rcol, msg=msg) + for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()): + self.assertEqual(lnull, rnull, msg=msg) + for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()): + self.assertAlmostEqual(lval, rval, msg=msg) + self.assertEqual(left.columns.names, right.columns.names, msg=msg) + elif isinstance(left, pd.Series) and isinstance(right, pd.Series): + msg = ( + "Series are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + self.assertEqual(left.name, right.name, msg=msg) + self.assertEqual(len(left), len(right), msg=msg) + for lnull, rnull in zip(left.isnull(), right.isnull()): + self.assertEqual(lnull, rnull, msg=msg) + for lval, rval in zip(left.dropna(), right.dropna()): + self.assertAlmostEqual(lval, rval, msg=msg) + elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex): + msg = ( + "MultiIndices are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + self.assertEqual(len(left), len(right), msg=msg) + for lval, rval in zip(left, right): + self.assertAlmostEqual(lval, rval, msg=msg) + elif isinstance(left, pd.Index) and isinstance(right, pd.Index): + msg = ( + "Indices are not almost equal: " + + 
"\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + self.assertEqual(len(left), len(right), msg=msg) + for lnull, rnull in zip(left.isnull(), right.isnull()): + self.assertEqual(lnull, rnull, msg=msg) + for lval, rval in zip(left.dropna(), right.dropna()): + self.assertAlmostEqual(lval, rval, msg=msg) + else: + raise ValueError("Unexpected values: (%s, %s)" % (left, right)) + + def assert_eq(self, left, right, check_exact=True, almost=False): + """ + Asserts if two arbitrary objects are equal or not. If given objects are Koalas DataFrame + or Series, they are converted into pandas' and compared. + + :param left: object to compare + :param right: object to compare + :param check_exact: if this is False, the comparison is done less precisely. + :param almost: if this is enabled, the comparison is delegated to `unittest`'s + `assertAlmostEqual`. See its documentation for more details. + """ + lobj = self._to_pandas(left) + robj = self._to_pandas(right) + if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)): + if almost: + self.assertPandasAlmostEqual(lobj, robj) + else: + self.assertPandasEqual(lobj, robj, check_exact=check_exact) + elif is_list_like(lobj) and is_list_like(robj): + self.assertTrue(len(left) == len(right)) + for litem, ritem in zip(left, right): + self.assert_eq(litem, ritem, check_exact=check_exact, almost=almost) + elif (lobj is not None and pd.isna(lobj)) and (robj is not None and pd.isna(robj)): + pass + else: + if almost: + self.assertAlmostEqual(lobj, robj) + else: + self.assertEqual(lobj, robj) + + @staticmethod + def _to_pandas(obj): + if isinstance(obj, (DataFrame, Series, Index)): + return obj.to_pandas() + else: + return obj + + +class TestUtils(object): + @contextmanager + def temp_dir(self): + tmp = tempfile.mkdtemp() + try: + yield tmp + finally: + shutil.rmtree(tmp) + + @contextmanager + def temp_file(self): + with self.temp_dir() as tmp: + yield tempfile.mktemp(dir=tmp) + + +class ComparisonTestBase(PandasOnSparkTestCase): + @property + def kdf(self): + return ps.from_pandas(self.pdf) + + @property + def pdf(self): + return self.kdf.to_pandas() + + +def compare_both(f=None, almost=True): + + if f is None: + return functools.partial(compare_both, almost=almost) + elif isinstance(f, bool): + return functools.partial(compare_both, almost=f) + + @functools.wraps(f) + def wrapped(self): + if almost: + compare = self.assertPandasAlmostEqual + else: + compare = self.assertPandasEqual + + for result_pandas, result_spark in zip(f(self, self.pdf), f(self, self.kdf)): + compare(result_pandas, result_spark.to_pandas()) + + return wrapped + + +@contextmanager +def assert_produces_warning( + expected_warning=Warning, + filter_level="always", + check_stacklevel=True, + raise_on_extra_warnings=True, +): + """ + Context manager for running code expected to either raise a specific + warning, or not raise any warnings. Verifies that the code raises the + expected warning, and that it does not raise any other unexpected + warnings. It is basically a wrapper around ``warnings.catch_warnings``. + + Notes + ----- + Replicated from pandas/_testing/_warnings.py. + + Parameters + ---------- + expected_warning : {Warning, False, None}, default Warning + The type of Exception raised. ``exception.Warning`` is the base + class for all warnings. To check that no warning is returned, + specify ``False`` or ``None``. + filter_level : str or None, default "always" + Specifies whether warnings are ignored, displayed, or turned + into errors. 
+ Valid values are: + * "error" - turns matching warnings into exceptions + * "ignore" - discard the warning + * "always" - always emit a warning + * "default" - print the warning the first time it is generated + from each location + * "module" - print the warning the first time it is generated + from each module + * "once" - print the warning the first time it is generated + check_stacklevel : bool, default True + If True, displays the line that called the function containing + the warning to show were the function is called. Otherwise, the + line that implements the function is displayed. + raise_on_extra_warnings : bool, default True + Whether extra warnings not of the type `expected_warning` should + cause the test to fail. + + Examples + -------- + >>> import warnings + >>> with assert_produces_warning(): + ... warnings.warn(UserWarning()) + ... + >>> with assert_produces_warning(False): # doctest: +SKIP + ... warnings.warn(RuntimeWarning()) + ... + Traceback (most recent call last): + ... + AssertionError: Caused unexpected warning(s): ['RuntimeWarning']. + >>> with assert_produces_warning(UserWarning): # doctest: +SKIP + ... warnings.warn(RuntimeWarning()) + Traceback (most recent call last): + ... + AssertionError: Did not see expected warning of class 'UserWarning' + ..warn:: This is *not* thread-safe. + """ + __tracebackhide__ = True + + with warnings.catch_warnings(record=True) as w: + + saw_warning = False + warnings.simplefilter(filter_level) + yield w + extra_warnings = [] + + for actual_warning in w: + if expected_warning and issubclass(actual_warning.category, expected_warning): + saw_warning = True + + if check_stacklevel and issubclass( + actual_warning.category, (FutureWarning, DeprecationWarning) + ): + from inspect import getframeinfo, stack + + caller = getframeinfo(stack()[2][0]) + msg = ( + "Warning not set with correct stacklevel. ", + "File where warning is raised: {} != ".format(actual_warning.filename), + "{}. Warning message: {}".format(caller.filename, actual_warning.message), + ) + assert actual_warning.filename == caller.filename, msg + else: + extra_warnings.append( + ( + actual_warning.category.__name__, + actual_warning.message, + actual_warning.filename, + actual_warning.lineno, + ) + ) + if expected_warning: + msg = "Did not see expected warning of class {}".format(repr(expected_warning.__name__)) + assert saw_warning, msg + if raise_on_extra_warnings and extra_warnings: + raise AssertionError("Caused unexpected warning(s): {}".format(repr(extra_warnings))) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index f960aa4fee..bbd93d1c38 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
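`ComparisonTestBase` and `compare_both` (defined above) pair a pandas frame with its pandas-on-Spark counterpart and compare every yielded result; a hypothetical subclass would use them like this:

```python
import pandas as pd
from pyspark.testing.pandasutils import ComparisonTestBase, compare_both


class ArithmeticComparisonTest(ComparisonTestBase):
    @property
    def pdf(self):
        # ComparisonTestBase derives self.kdf from this pandas DataFrame.
        return pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

    @compare_both(almost=False)
    def test_simple_arithmetic(self, df):
        # Called once with the pandas frame and once with the pandas-on-Spark frame;
        # each yielded pair is compared after converting the Spark result to pandas.
        yield (df.a + df.b).sort_index()
        yield (df.a * 2).sort_index()
```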
# + import glob import os import struct @@ -29,13 +30,13 @@ try: import scipy.sparse # noqa: F401 have_scipy = True -except: +except ImportError: # No SciPy, but that's okay, we'll skip those tests pass try: import numpy as np # noqa: F401 have_numpy = True -except: +except ImportError: # No NumPy, but that's okay, we'll skip those tests pass diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 174c62183c..aa41b7ec2d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -62,6 +62,14 @@ private[spark] object Config extends Logging { .booleanConf .createWithDefault(true) + val KUBERNETES_DRIVER_OWN_PVC = + ConfigBuilder("spark.kubernetes.driver.ownPersistentVolumeClaim") + .doc("If true, driver pod becomes the owner of on-demand persistent volume claims " + + "instead of the executor pods") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + val KUBERNETES_NAMESPACE = ConfigBuilder("spark.kubernetes.namespace") .doc("The namespace that will be used for running the driver and executor pods.") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index c66756fd69..4e1647372e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import io.fabric8.kubernetes.api.model._ import org.apache.spark.deploy.k8s._ -import org.apache.spark.deploy.k8s.Constants.ENV_EXECUTOR_ID +import org.apache.spark.deploy.k8s.Constants.{ENV_EXECUTOR_ID, SPARK_APP_ID_LABEL} private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) extends KubernetesFeatureConfigStep { @@ -85,6 +85,7 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) .withApiVersion("v1") .withNewMetadata() .withName(claimName) + .addToLabels(SPARK_APP_ID_LABEL, conf.sparkConf.getAppId) .endMetadata() .withNewSpec() .withStorageClassName(storageClass.get) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala index 5ebd172f7d..d54f665a38 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -339,6 +339,9 @@ private[spark] class ExecutorPodsAllocator( resources .filter(_.getKind == "PersistentVolumeClaim") .foreach { resource => + if (conf.get(KUBERNETES_DRIVER_OWN_PVC) && driverPod.nonEmpty) { + addOwnerReference(driverPod.get, Seq(resource)) + } val pvc = resource.asInstanceOf[PersistentVolumeClaim] logInfo(s"Trying to create PersistentVolumeClaim ${pvc.getMetadata.getName} with " + s"StorageClass ${pvc.getSpec.getStorageClassName}") diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 780b08bd0e..d5a4856d37 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -134,6 +134,13 @@ private[spark] class KubernetesClusterSchedulerBackend( } } + Utils.tryLogNonFatalError { + kubernetesClient + .persistentVolumeClaims() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .delete() + } + if (shouldDeleteExecutors) { Utils.tryLogNonFatalError { kubernetesClient diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index c958f9c387..5566687054 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -606,7 +606,13 @@ groupByClause ; groupingAnalytics - : (ROLLUP | CUBE | GROUPING SETS) '(' groupingSet (',' groupingSet)* ')' + : (ROLLUP | CUBE) '(' groupingSet (',' groupingSet)* ')' + | GROUPING SETS '(' groupingElement (',' groupingElement)* ')' + ; + +groupingElement + : groupingAnalytics + | groupingSet ; groupingSet diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java new file mode 100644 index 0000000000..71e83002dd --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.metric; + +import org.apache.spark.annotation.Evolving; + +import java.util.Arrays; +import java.text.DecimalFormat; + +/** + * Built-in `CustomMetric` that computes average of metric values. Note that please extend this + * class and override `name` and `description` to create your custom metric for real usage. 
+ * + * @since 3.2.0 + */ +@Evolving +public abstract class CustomAvgMetric implements CustomMetric { + @Override + public String aggregateTaskMetrics(long[] taskMetrics) { + if (taskMetrics.length > 0) { + double average = ((double)Arrays.stream(taskMetrics).sum()) / taskMetrics.length; + return new DecimalFormat("#0.000").format(average); + } else { + return "0"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomMetric.java new file mode 100644 index 0000000000..4c4151ad96 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomMetric.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.metric; + +import org.apache.spark.annotation.Evolving; + +/** + * A custom metric. Data source can define supported custom metrics using this interface. + * During query execution, Spark will collect the task metrics using {@link CustomTaskMetric} + * and combine the metrics at the driver side. How to combine task metrics is defined by the + * metric class with the same metric name. + * + * @since 3.2.0 + */ +@Evolving +public interface CustomMetric { + /** + * Returns the name of custom metric. + */ + String name(); + + /** + * Returns the description of custom metric. + */ + String description(); + + /** + * The initial value of this metric. + */ + long initialValue = 0L; + + /** + * Given an array of task metric values, returns aggregated final metric value. + */ + String aggregateTaskMetrics(long[] taskMetrics); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java new file mode 100644 index 0000000000..ba28e9b918 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.metric; + +import org.apache.spark.annotation.Evolving; + +import java.util.Arrays; + +/** + * Built-in `CustomMetric` that sums up metric values. Note that please extend this class + * and override `name` and `description` to create your custom metric for real usage. + * + * @since 3.2.0 + */ +@Evolving +public abstract class CustomSumMetric implements CustomMetric { + @Override + public String aggregateTaskMetrics(long[] taskMetrics) { + return String.valueOf(Arrays.stream(taskMetrics).sum()); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java new file mode 100644 index 0000000000..1b6f04d927 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.metric; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.read.PartitionReader; + +/** + * A custom task metric. This is a logical representation of a metric reported by data sources + * at the executor side. During query execution, Spark will collect the task metrics per partition + * by {@link PartitionReader} and update internal metrics based on collected metric values. + * For streaming query, Spark will collect and combine metrics for a final result per micro batch. + *

+ * The metrics will be gathered during query execution back to the driver and then combined. How + * the task metrics are combined is defined by corresponding {@link CustomMetric} with same metric + * name. The final result will be shown up in the data source scan operator in Spark UI. + * + * @since 3.2.0 + */ +@Evolving +public interface CustomTaskMetric { + /** + * Returns the name of custom task metric. + */ + String name(); + + /** + * Returns the long value of custom task metric. + */ + long value(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java index d6cf070cf4..5286bbf9f8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java @@ -21,7 +21,7 @@ import java.io.IOException; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.CustomTaskMetric; +import org.apache.spark.sql.connector.metric.CustomTaskMetric; /** * A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java index b70a656c49..0c009f5c56 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Scan.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.read; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.CustomMetric; +import org.apache.spark.sql.connector.metric.CustomMetric; import org.apache.spark.sql.connector.read.streaming.ContinuousStream; import org.apache.spark.sql.connector.read.streaming.MicroBatchStream; import org.apache.spark.sql.types.StructType; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index f149c9bb0c..fe48670cb3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -19,12 +19,16 @@ import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; +import org.apache.arrow.vector.holders.NullableIntervalDayHolder; import org.apache.arrow.vector.holders.NullableVarCharHolder; import org.apache.spark.sql.util.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; +import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY; +import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS; + /** * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not * supported. 
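The connector metric classes introduced above are abstract on purpose: a source supplies the `name()`/`description()` pair on the driver side and reports a matching `CustomTaskMetric` per task. Below is a minimal Scala sketch of that pairing; the metric name `bytesRead`, both class names, and the wiring through `Scan.supportedCustomMetrics()` and `PartitionReader.currentMetricsValues()` are illustrative assumptions rather than part of this patch.

```scala
import org.apache.spark.sql.connector.metric.{CustomSumMetric, CustomTaskMetric}

// Driver-side metric (hypothetical): name/description are the required overrides;
// the aggregation (summing the per-task values) is inherited from CustomSumMetric.
class BytesReadMetric extends CustomSumMetric {
  override def name(): String = "bytesRead"
  override def description(): String = "total bytes read by the scan"
}

// Executor-side value reported by a partition reader for one task; Spark matches it
// to the driver-side metric by name(), as described in the CustomTaskMetric javadoc.
class BytesReadTaskMetric(bytes: Long) extends CustomTaskMetric {
  override def name(): String = "bytesRead"
  override def value(): Long = bytes
}
```

A scan would then advertise `new BytesReadMetric` from `supportedCustomMetrics()` and a reader would return a `BytesReadTaskMetric` from `currentMetricsValues()` (both assumed from the existing DSv2 API); the aggregated string surfaces in the scan node of the Spark UI.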
@@ -172,6 +176,10 @@ public ArrowColumnVector(ValueVector vector) { } } else if (vector instanceof NullVector) { accessor = new NullAccessor((NullVector) vector); + } else if (vector instanceof IntervalYearVector) { + accessor = new IntervalYearAccessor((IntervalYearVector) vector); + } else if (vector instanceof IntervalDayVector) { + accessor = new IntervalDayAccessor((IntervalDayVector) vector); } else { throw new UnsupportedOperationException(); } @@ -508,4 +516,37 @@ private static class NullAccessor extends ArrowVectorAccessor { super(vector); } } + + private static class IntervalYearAccessor extends ArrowVectorAccessor { + + private final IntervalYearVector accessor; + + IntervalYearAccessor(IntervalYearVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + int getInt(int rowId) { + return accessor.get(rowId); + } + } + + private static class IntervalDayAccessor extends ArrowVectorAccessor { + + private final IntervalDayVector accessor; + private final NullableIntervalDayHolder intervalDayHolder = new NullableIntervalDayHolder(); + + IntervalDayAccessor(IntervalDayVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + long getLong(int rowId) { + accessor.get(rowId, intervalDayHolder); + return Math.addExact(Math.multiplyExact(intervalDayHolder.days, MICROS_PER_DAY), + intervalDayHolder.milliseconds * MICROS_PER_MILLIS); + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index b55d1b725f..ccf0a50b73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -304,18 +304,23 @@ object CatalystTypeConverters { row.getUTF8String(column).toString } - private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { - override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue) + private object DateConverter extends CatalystTypeConverter[Any, Date, Any] { + override def toCatalystImpl(scalaValue: Any): Int = scalaValue match { + case d: Date => DateTimeUtils.fromJavaDate(d) + case l: LocalDate => DateTimeUtils.localDateToDays(l) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to the ${DateType.sql} type") + } override def toScala(catalystValue: Any): Date = if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int]) override def toScalaImpl(row: InternalRow, column: Int): Date = DateTimeUtils.toJavaDate(row.getInt(column)) } - private object LocalDateConverter extends CatalystTypeConverter[LocalDate, LocalDate, Any] { - override def toCatalystImpl(scalaValue: LocalDate): Int = { - DateTimeUtils.localDateToDays(scalaValue) - } + private object LocalDateConverter extends CatalystTypeConverter[Any, LocalDate, Any] { + override def toCatalystImpl(scalaValue: Any): Int = + DateConverter.toCatalystImpl(scalaValue) override def toScala(catalystValue: Any): LocalDate = { if (catalystValue == null) null else DateTimeUtils.daysToLocalDate(catalystValue.asInstanceOf[Int]) @@ -324,9 +329,14 @@ object CatalystTypeConverters { DateTimeUtils.daysToLocalDate(row.getInt(column)) } - private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] { - override def 
toCatalystImpl(scalaValue: Timestamp): Long = - DateTimeUtils.fromJavaTimestamp(scalaValue) + private object TimestampConverter extends CatalystTypeConverter[Any, Timestamp, Any] { + override def toCatalystImpl(scalaValue: Any): Long = scalaValue match { + case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t) + case i: Instant => DateTimeUtils.instantToMicros(i) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to the ${TimestampType.sql} type") + } override def toScala(catalystValue: Any): Timestamp = if (catalystValue == null) null else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long]) @@ -334,9 +344,9 @@ object CatalystTypeConverters { DateTimeUtils.toJavaTimestamp(row.getLong(column)) } - private object InstantConverter extends CatalystTypeConverter[Instant, Instant, Any] { - override def toCatalystImpl(scalaValue: Instant): Long = - DateTimeUtils.instantToMicros(scalaValue) + private object InstantConverter extends CatalystTypeConverter[Any, Instant, Any] { + override def toCatalystImpl(scalaValue: Any): Long = + TimestampConverter.toCatalystImpl(scalaValue) override def toScala(catalystValue: Any): Instant = if (catalystValue == null) null else DateTimeUtils.microsToInstant(catalystValue.asInstanceOf[Long]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c2c146c7de..87b8d52ac2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -39,9 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.trees.TreePattern.{ - EXPRESSION_WITH_RANDOM_SEED, NATURAL_LIKE_JOIN, WINDOW_EXPRESSION -} +import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -2179,7 +2177,8 @@ class Analyzer(override val catalogManager: CatalogManager) * outer plan to get evaluated. */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { - plan transformExpressions { + plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY, + EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => @@ -2196,7 +2195,8 @@ class Analyzer(override val catalogManager: CatalogManager) /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( + _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. 
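Several rules in this patch move from plain `resolveOperators`/`transformExpressions` to the `*WithPruning` variants so that subtrees whose cached `TreePattern` bits cannot match are skipped without being traversed. The following schematic Scala sketch shows the idiom; the rule name and the no-op rewrite body are hypothetical, and only the pruning API itself (`resolveOperatorsWithPruning`, `containsPattern`, the optional `ruleId` argument) is taken from the surrounding changes.

```scala
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER

object MySimplifyFilters extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan =
    // The predicate is evaluated against pattern bits cached on each tree node, so any
    // subtree that contains no Filter is skipped entirely. Built-in rules additionally
    // pass their ruleId so plans already fixed by the rule are not revisited.
    plan.resolveOperatorsWithPruning(_.containsPattern(FILTER)) {
      case f @ Filter(condition, _) if condition.deterministic =>
        f // the actual rewrite would go here
    }
}
```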
case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -3790,9 +3790,9 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { } def apply(plan: LogicalPlan): LogicalPlan = { - plan resolveOperators { + plan.resolveOperatorsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, FILTER), ruleId) { case f @ Filter(_, a: Aggregate) if f.resolved => - f transformExpressions { + f.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) { case s: SubqueryExpression if s.children.nonEmpty => // Collect the aliases from output of aggregate. val outerAliases = a.aggregateExpressions collect { case a: Alias => a } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala index e9673d7f20..1f3f762662 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala @@ -35,7 +35,7 @@ trait AliasHelper { protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = { // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression or PythonUDF, and create a map from the alias to the expression - val aliasMap = plan.aggregateExpressionsWithoutGroupingRefs.collect { + val aliasMap = plan.aggregateExpressions.collect { case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(e)).isEmpty => (a.toAttribute, a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5d799c768a..30317c9e91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvable import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.TreePattern.{CAST, TreePattern} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ @@ -1800,6 +1801,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) + final override val nodePatterns: Seq[TreePattern] = Seq(CAST) + override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala index de4b874637..1c185dd316 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, TreePattern} import org.apache.spark.sql.catalyst.trees.UnaryLike trait DynamicPruning extends Predicate @@ -69,6 +70,8 @@ case class DynamicPruningSubquery( pruningKey.dataType == buildKeys(broadcastKeyIndex).dataType } + final override def nodePatternsInternal: Seq[TreePattern] = Seq(DYNAMIC_PRUNING_SUBQUERY) + override def toString: String = s"dynamicpruning#${exprId.id} $conditionString" override lazy val canonicalized: DynamicPruning = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index 03b5517f6d..a6be98c8a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -42,7 +42,8 @@ case class ProjectionOverSchema(schema: StructType) { getProjection(a.child).map(p => (p, p.dataType)).map { case (projection, ArrayType(projSchema @ StructType(_), _)) => // For case-sensitivity aware field resolution, we should take `ordinal` which - // points to correct struct field. + // points to correct struct field, because `ExtractValue` actually does column + // name resolving correctly. val selectedField = a.child.dataType.asInstanceOf[ArrayType] .elementType.asInstanceOf[StructType](a.ordinal) val prunedField = projSchema(selectedField.name) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala index 4ee6488c92..30093ef085 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala @@ -22,9 +22,12 @@ import org.apache.spark.sql.types._ object SchemaPruning extends SQLConfHelper { /** - * Filters the schema by the requested fields. For example, if the schema is struct<a:int, b:int>, - * and given requested field are "a", the field "b" is pruned in the returned schema. - * Note that schema field ordering at original schema is still preserved in pruned schema. + * Prunes the nested schema by the requested fields. For example, if the schema is: + * `id int, s struct<a:int, b:int>`, and given requested field "s.a", the inner field "b" + * is pruned in the returned schema: `id int, s struct<a:int>`. + * Note that: + * 1. The schema field ordering at original schema is still preserved in pruned schema. + * 2. The top-level fields are not pruned here. 
*/ def pruneDataSchema( dataSchema: StructType, @@ -34,11 +37,10 @@ object SchemaPruning extends SQLConfHelper { // in the resulting schema may differ from their ordering in the logical relation's // original schema val mergedSchema = requestedRootFields - .map { case root: RootField => StructType(Array(root.field)) } + .map { root: RootField => StructType(Array(root.field)) } .reduceLeft(_ merge _) - val dataSchemaFieldNames = dataSchema.fieldNames.toSet val mergedDataSchema = - StructType(mergedSchema.filter(f => dataSchemaFieldNames.exists(resolver(_, f.name)))) + StructType(dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d))) // Sort the fields of mergedDataSchema according to their order in dataSchema, // recursively. This makes mergedDataSchema a pruned schema of dataSchema sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 4fc0256bce..8ae24e5135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -87,8 +87,12 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case _: DecimalType => DecimalPrecision.decimalAndDecimal()( Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType) - case _: YearMonthIntervalType => DivideYMInterval(sum, count) - case _: DayTimeIntervalType => DivideDTInterval(sum, count) + case _: YearMonthIntervalType => + If(EqualTo(count, Literal(0L)), + Literal(null, YearMonthIntervalType), DivideYMInterval(sum, count)) + case _: DayTimeIntervalType => + If(EqualTo(count, Literal(0L)), + Literal(null, DayTimeIntervalType), DivideDTInterval(sum, count)) case _ => Divide(sum.cast(resultType), count.cast(resultType), failOnError = false) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 1d13155ef6..dfdd828d10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT, TreePattern} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -48,6 +49,8 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false + final override val nodePatterns: Seq[TreePattern] = Seq(COUNT) + // Return data type. 
override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 8c70c86aa1..281734c6f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -80,14 +80,6 @@ object AggregateExpression { filter, NamedExpression.newExprId) } - - def containsAggregate(expr: Expression): Boolean = { - expr.find(isAggregate).isDefined - } - - def isAggregate(expr: Expression): Boolean = { - expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 954a4b9fc1..10b4a7be30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern, + UNARY_POSITIVE} import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -128,6 +130,8 @@ case class UnaryPositive(child: Expression) override def dataType: DataType = child.dataType + final override val nodePatterns: Seq[TreePattern] = Seq(UNARY_POSITIVE) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = defineCodeGen(ctx, ev, c => c) @@ -199,6 +203,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType + final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC) + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** Name of the function for this expression on a [[Decimal]] type. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 125e796a98..a408280a3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.catalyst.expressions -import java.time.ZoneId +import java.time.{Duration, Period, ZoneId} import java.util.Comparator import scala.collection.mutable @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{CONCAT, TreePattern} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ @@ -2172,6 +2173,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) + final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT) + override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { TypeCheckResult.TypeCheckSuccess @@ -2484,8 +2487,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran The start and stop expressions must resolve to the same type. If start and stop expressions resolve to the 'date' or 'timestamp' type - then the step expression must resolve to the 'interval' type, otherwise to the same type - as the start and stop expressions. + then the step expression must resolve to the 'interval' or 'year-month interval' or + 'day-time interval' type, otherwise to the same type as the start and stop expressions. """, arguments = """ Arguments: @@ -2504,6 +2507,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran [5,4,3,2,1] > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month); [2018-01-01,2018-02-01,2018-03-01] + > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval '0-1' year to month); + [2018-01-01,2018-02-01,2018-03-01] """, group = "array_funcs", since = "2.4.0" @@ -2550,8 +2555,13 @@ case class Sequence( val typesCorrect = startType.sameType(stop.dataType) && (startType match { - case TimestampType | DateType => - stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) + case TimestampType => + stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) || + YearMonthIntervalType.acceptsType(stepType) || + DayTimeIntervalType.acceptsType(stepType) + case DateType => + stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) || + YearMonthIntervalType.acceptsType(stepType) case _: IntegralType => stepOpt.isEmpty || stepType.sameType(startType) case _ => false @@ -2561,29 +2571,51 @@ case class Sequence( TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure( - s"$prettyName only supports integral, timestamp or date types") + s""" + |$prettyName uses the wrong parameter type. The parameter type must conform to: + |1. The start and stop expressions must resolve to the same type. + |2. 
If start and stop expressions resolve to the 'date' or 'timestamp' type + |then the step expression must resolve to the 'interval' or + |'${YearMonthIntervalType.typeName}' or '${DayTimeIntervalType.typeName}' type, + |otherwise to the same type as the start and stop expressions. + """.stripMargin) } } - def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType) + private def isNotIntervalType(expr: Expression) = expr.dataType match { + case CalendarIntervalType | YearMonthIntervalType | DayTimeIntervalType => false + case _ => true + } + + def coercibleChildren: Seq[Expression] = children.filter(isNotIntervalType) def castChildrenTo(widerType: DataType): Expression = Sequence( Cast(start, widerType), Cast(stop, widerType), - stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), + stepOpt.map(step => if (isNotIntervalType(step)) Cast(step, widerType) else step), timeZoneId) - @transient private lazy val impl: SequenceImpl = dataType.elementType match { + @transient private lazy val impl: InternalSequence = dataType.elementType match { case iType: IntegralType => type T = iType.InternalType val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) new IntegralSequenceImpl(iType)(ct, iType.integral) case TimestampType => - new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId) + if (stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepOpt.get.dataType)) { + new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId) + } else if (YearMonthIntervalType.acceptsType(stepOpt.get.dataType)) { + new PeriodSequenceImpl[Long](LongType, 1, identity, zoneId) + } else { + new DurationSequenceImpl[Long](LongType, 1, identity, zoneId) + } case DateType => - new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId) + if (stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepOpt.get.dataType)) { + new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId) + } else { + new PeriodSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId) + } } override def eval(input: InternalRow): Any = { @@ -2666,7 +2698,7 @@ object Sequence { } } - private trait SequenceImpl { + private trait InternalSequence { def eval(start: Any, stop: Any, step: Any): Any def genCode( @@ -2681,7 +2713,7 @@ object Sequence { } private class IntegralSequenceImpl[T: ClassTag] - (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl { + (elemType: IntegralType)(implicit num: Integral[T]) extends InternalSequence { override val defaultStep: DefaultStep = new DefaultStep( (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], @@ -2695,7 +2727,7 @@ object Sequence { val stop = input2.asInstanceOf[T] val step = input3.asInstanceOf[T] - var i: Int = getSequenceLength(start, stop, step) + var i: Int = getSequenceLength(start, stop, step, step) val arr = new Array[T](i) while (i > 0) { i -= 1 @@ -2713,7 +2745,7 @@ object Sequence { elemType: String): String = { val i = ctx.freshName("i") s""" - |${genSequenceLengthCode(ctx, start, stop, step, i)} + |${genSequenceLengthCode(ctx, start, stop, step, step, i)} |$arr = new $elemType[$i]; |while ($i > 0) { | $i--; @@ -2723,32 +2755,105 @@ object Sequence { } } + private class PeriodSequenceImpl[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) + (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) { + + override val defaultStep: DefaultStep = new DefaultStep( + 
(dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + YearMonthIntervalType, + Period.of(0, 1, 0)) + + val intervalType: DataType = YearMonthIntervalType + + def splitStep(input: Any): (Int, Int, Long) = { + (input.asInstanceOf[Int], 0, 0) + } + + def stepSplitCode( + stepMonths: String, stepDays: String, stepMicros: String, step: String): String = { + s""" + |final int $stepMonths = $step; + |final int $stepDays = 0; + |final long $stepMicros = 0L; + """.stripMargin + } + } + + private class DurationSequenceImpl[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) + (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) { + + override val defaultStep: DefaultStep = new DefaultStep( + (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + DayTimeIntervalType, + Duration.ofDays(1)) + + val intervalType: DataType = DayTimeIntervalType + + def splitStep(input: Any): (Int, Int, Long) = { + (0, 0, input.asInstanceOf[Long]) + } + + def stepSplitCode( + stepMonths: String, stepDays: String, stepMicros: String, step: String): String = { + s""" + |final int $stepMonths = 0; + |final int $stepDays = 0; + |final long $stepMicros = $step; + """.stripMargin + } + } + private class TemporalSequenceImpl[T: ClassTag] (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) - (implicit num: Integral[T]) extends SequenceImpl { + (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) { override val defaultStep: DefaultStep = new DefaultStep( (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], CalendarIntervalType, new CalendarInterval(0, 1, 0)) + val intervalType: DataType = CalendarIntervalType + + def splitStep(input: Any): (Int, Int, Long) = { + val step = input.asInstanceOf[CalendarInterval] + (step.months, step.days, step.microseconds) + } + + def stepSplitCode( + stepMonths: String, stepDays: String, stepMicros: String, step: String): String = { + s""" + |final int $stepMonths = $step.months; + |final int $stepDays = $step.days; + |final long $stepMicros = $step.microseconds; + """.stripMargin + } + } + + private abstract class InternalSequenceBase[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) + (implicit num: Integral[T]) extends InternalSequence { + + val defaultStep: DefaultStep + private val backedSequenceImpl = new IntegralSequenceImpl[T](dt) - private val microsPerDay = HOURS_PER_DAY * MICROS_PER_HOUR // We choose a minimum days(28) in one month to calculate the `intervalStepInMicros` // in order to make sure the estimated array length is long enough - private val microsPerMonth = 28 * microsPerDay + private val microsPerMonth = 28 * MICROS_PER_DAY + + protected val intervalType: DataType + + protected def splitStep(input: Any): (Int, Int, Long) override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { val start = input1.asInstanceOf[T] val stop = input2.asInstanceOf[T] - val step = input3.asInstanceOf[CalendarInterval] - val stepMonths = step.months - val stepDays = step.days - val stepMicros = step.microseconds + val (stepMonths, stepDays, stepMicros) = splitStep(input3) if (scale == MICROS_PER_DAY && stepMonths == 0 && stepDays == 0) { throw new IllegalArgumentException( - "sequence step must be a day interval if start and end values are dates") + s"sequence step must be a day ${intervalType.typeName} if start and end values are dates") } if (stepMonths == 0 && stepMicros == 0 && scale == MICROS_PER_DAY) { @@ -2763,11 +2868,12 @@ 
object Sequence { // To estimate the resulted array length we need to make assumptions // about a month length in days and a day length in microseconds val intervalStepInMicros = - stepMicros + stepMonths * microsPerMonth + stepDays * microsPerDay + stepMicros + stepMonths * microsPerMonth + stepDays * MICROS_PER_DAY val startMicros: Long = num.toLong(start) * scale val stopMicros: Long = num.toLong(stop) * scale + val maxEstimatedArrayLength = - getSequenceLength(startMicros, stopMicros, intervalStepInMicros) + getSequenceLength(startMicros, stopMicros, input3, intervalStepInMicros) val stepSign = if (stopMicros >= startMicros) +1 else -1 val exclusiveItem = stopMicros + stepSign @@ -2787,6 +2893,9 @@ object Sequence { } } + protected def stepSplitCode( + stepMonths: String, stepDays: String, stepMicros: String, step: String): String + override def genCode( ctx: CodegenContext, start: String, @@ -2811,25 +2920,27 @@ object Sequence { val sequenceLengthCode = s""" |final long $intervalInMicros = - | $stepMicros + $stepMonths * ${microsPerMonth}L + $stepDays * ${microsPerDay}L; - |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)} - """.stripMargin + | $stepMicros + $stepMonths * ${microsPerMonth}L + $stepDays * ${MICROS_PER_DAY}L; + |${genSequenceLengthCode( + ctx, startMicros, stopMicros, step, intervalInMicros, arrLength)} + """.stripMargin val check = if (scale == MICROS_PER_DAY) { s""" |if ($stepMonths == 0 && $stepDays == 0) { | throw new IllegalArgumentException( - | "sequence step must be a day interval if start and end values are dates"); + | "sequence step must be a day ${intervalType.typeName} " + + | "if start and end values are dates"); |} - """.stripMargin + """.stripMargin } else { "" } + val stepSplits = stepSplitCode(stepMonths, stepDays, stepMicros, step) + s""" - |final int $stepMonths = $step.months; - |final int $stepDays = $step.days; - |final long $stepMicros = $step.microseconds; + |$stepSplits | |$check | @@ -2866,15 +2977,16 @@ object Sequence { } } - private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = { + private def getSequenceLength[U](start: U, stop: U, step: Any, estimatedStep: U) + (implicit num: Integral[U]): Int = { import num._ require( - (step > num.zero && start <= stop) - || (step < num.zero && start >= stop) - || (step == num.zero && start == stop), + (estimatedStep > num.zero && start <= stop) + || (estimatedStep < num.zero && start >= stop) + || (estimatedStep == num.zero && start == stop), s"Illegal sequence boundaries: $start to $stop by $step") - val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong + val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong require( len <= MAX_ROUNDED_ARRAY_LENGTH, @@ -2888,16 +3000,17 @@ object Sequence { start: String, stop: String, step: String, + estimatedStep: String, len: String): String = { val longLen = ctx.freshName("longLen") s""" - |if (!(($step > 0 && $start <= $stop) || - | ($step < 0 && $start >= $stop) || - | ($step == 0 && $start == $stop))) { + |if (!(($estimatedStep > 0 && $start <= $stop) || + | ($estimatedStep < 0 && $start >= $stop) || + | ($estimatedStep == 0 && $start == $stop))) { | throw new IllegalArgumentException( | "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step); |} - |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step; + |long $longLen = $stop == $start ? 
1L : 1L + ((long) $stop - $start) / $estimatedStep; |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) { | throw new IllegalArgumentException( | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH"); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index e708d56cd8..3e356f1e8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TernaryLike +import org.apache.spark.sql.catalyst.trees.TreePattern.{CASE_WHEN, IF, TreePattern} import org.apache.spark.sql.types._ // scalastyle:off line.size.limit @@ -48,6 +49,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def third: Expression = falseValue override def nullable: Boolean = trueValue.nullable || falseValue.nullable + final override val nodePatterns : Seq[TreePattern] = Seq(IF) + override def checkInputDataTypes(): TypeCheckResult = { if (predicate.dataType != BooleanType) { TypeCheckResult.TypeCheckFailure( @@ -139,6 +142,8 @@ case class CaseWhen( override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue + final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = super.legacyWithNewChildren(newChildren) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 010b9b0fae..e69bf2e5f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -2378,15 +2378,15 @@ object DatePart { Literal(null, DoubleType) } else { val fieldStr = fieldEval.asInstanceOf[UTF8String].toString - val analysisException = QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError( - fieldStr, source) - if (source.dataType == CalendarIntervalType) { - ExtractIntervalPart.parseExtractField( - fieldStr, - source, - throw analysisException) - } else { - DatePart.parseExtractField(fieldStr, source, throw analysisException) + + def analysisException = + throw QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(fieldStr, source) + + source.dataType match { + case YearMonthIntervalType | DayTimeIntervalType | CalendarIntervalType => + ExtractIntervalPart.parseExtractField(fieldStr, source, analysisException) + case _ => + DatePart.parseExtractField(fieldStr, source, analysisException) } } } @@ -2414,6 +2414,10 @@ object DatePart { 5 > SELECT _FUNC_('seconds', interval 5 hours 30 seconds 1 milliseconds 1 microseconds); 30.001001 + > SELECT _FUNC_('MONTH', INTERVAL '2021-11' YEAR TO MONTH); + 11 + > SELECT _FUNC_('MINUTE', INTERVAL '123 23:55:59.002001' DAY TO SECOND); + 55 """, note = """ The _FUNC_ function is equivalent to the SQL-standard function `EXTRACT(field FROM 
source)` @@ -2479,6 +2483,10 @@ case class DatePart(field: Expression, source: Expression, child: Expression) 5 > SELECT _FUNC_(seconds FROM interval 5 hours 30 seconds 1 milliseconds 1 microseconds); 30.001001 + > SELECT _FUNC_(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH); + 11 + > SELECT _FUNC_(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND); + 55 """, note = """ The _FUNC_ function is equivalent to `date_part(field, source)`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index 808c8222d2..aff1806582 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -214,9 +214,9 @@ case class Grouping(child: Expression) extends Expression with Unevaluable Examples: > SELECT name, _FUNC_(), sum(age), avg(height) FROM VALUES (2, 'Alice', 165), (5, 'Bob', 180) people(age, name, height) GROUP BY cube(name, height); Alice 0 2 165.0 + Bob 0 5 180.0 Alice 1 2 165.0 NULL 3 7 172.5 - Bob 0 5 180.0 Bob 1 5 180.0 NULL 2 2 165.0 NULL 2 5 180.0 @@ -277,22 +277,3 @@ object GroupingAnalytics { } } } - -/** - * A reference to an grouping expression in [[Aggregate]] node. - * - * @param ordinal The ordinal of the grouping expression in [[Aggregate]] that this expression - * refers to. - * @param dataType The [[DataType]] of the referenced grouping expression. - * @param nullable True if null is a valid value for the referenced grouping expression. - */ -case class GroupingExprRef( - ordinal: Int, - dataType: DataType, - nullable: Boolean) - extends LeafExpression with Unevaluable { - - override def stringArgs: Iterator[Any] = { - Iterator(ordinal) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index b34bcaf5ce..94ca6cc65d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -25,78 +25,131 @@ import com.google.common.math.{DoubleMath, IntMath, LongMath} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -abstract class ExtractIntervalPart( - child: Expression, +abstract class ExtractIntervalPart[T]( val dataType: DataType, - func: CalendarInterval => Any, - funcName: String) - extends UnaryExpression with ExpectsInputTypes with NullIntolerant with Serializable { - - override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType) - - override protected def nullSafeEval(interval: Any): Any = { - func(interval.asInstanceOf[CalendarInterval]) - } - + func: T => Any, + funcName: String) extends UnaryExpression with NullIntolerant with Serializable { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val iu = IntervalUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$iu.$funcName($c)") } + + override protected 
def nullSafeEval(interval: Any): Any = { + func(interval.asInstanceOf[T]) + } } case class ExtractIntervalYears(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getYears, "getYears") { + extends ExtractIntervalPart[CalendarInterval](IntegerType, getYears, "getYears") { override protected def withNewChildInternal(newChild: Expression): ExtractIntervalYears = copy(child = newChild) } case class ExtractIntervalMonths(child: Expression) - extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths") { + extends ExtractIntervalPart[CalendarInterval](ByteType, getMonths, "getMonths") { override protected def withNewChildInternal(newChild: Expression): ExtractIntervalMonths = copy(child = newChild) } case class ExtractIntervalDays(child: Expression) - extends ExtractIntervalPart(child, IntegerType, getDays, "getDays") { + extends ExtractIntervalPart[CalendarInterval](IntegerType, getDays, "getDays") { override protected def withNewChildInternal(newChild: Expression): ExtractIntervalDays = copy(child = newChild) } case class ExtractIntervalHours(child: Expression) - extends ExtractIntervalPart(child, LongType, getHours, "getHours") { + extends ExtractIntervalPart[CalendarInterval](ByteType, getHours, "getHours") { override protected def withNewChildInternal(newChild: Expression): ExtractIntervalHours = copy(child = newChild) } case class ExtractIntervalMinutes(child: Expression) - extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes") { + extends ExtractIntervalPart[CalendarInterval](ByteType, getMinutes, "getMinutes") { override protected def withNewChildInternal(newChild: Expression): ExtractIntervalMinutes = copy(child = newChild) } case class ExtractIntervalSeconds(child: Expression) - extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, "getSeconds") { + extends ExtractIntervalPart[CalendarInterval](DecimalType(8, 6), getSeconds, "getSeconds") { override protected def withNewChildInternal(newChild: Expression): ExtractIntervalSeconds = copy(child = newChild) } +case class ExtractANSIIntervalYears(child: Expression) + extends ExtractIntervalPart[Int](IntegerType, getYears, "getYears") { + override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalYears = + copy(child = newChild) +} + +case class ExtractANSIIntervalMonths(child: Expression) + extends ExtractIntervalPart[Int](ByteType, getMonths, "getMonths") { + override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalMonths = + copy(child = newChild) +} + +case class ExtractANSIIntervalDays(child: Expression) + extends ExtractIntervalPart[Long](IntegerType, getDays, "getDays") { + override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalDays = { + copy(child = newChild) + } +} + +case class ExtractANSIIntervalHours(child: Expression) + extends ExtractIntervalPart[Long](ByteType, getHours, "getHours") { + override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalHours = + copy(child = newChild) +} + +case class ExtractANSIIntervalMinutes(child: Expression) + extends ExtractIntervalPart[Long](ByteType, getMinutes, "getMinutes") { + override protected def withNewChildInternal(newChild: Expression): ExtractANSIIntervalMinutes = + copy(child = newChild) +} + +case class ExtractANSIIntervalSeconds(child: Expression) + extends ExtractIntervalPart[Long](DecimalType(8, 6), getSeconds, "getSeconds") { + override protected def withNewChildInternal(newChild: Expression): 
ExtractANSIIntervalSeconds = + copy(child = newChild) +} + object ExtractIntervalPart { def parseExtractField( extractField: String, source: Expression, - errorHandleFunc: => Nothing): Expression = extractField.toUpperCase(Locale.ROOT) match { - case "YEAR" | "Y" | "YEARS" | "YR" | "YRS" => ExtractIntervalYears(source) - case "MONTH" | "MON" | "MONS" | "MONTHS" => ExtractIntervalMonths(source) - case "DAY" | "D" | "DAYS" => ExtractIntervalDays(source) - case "HOUR" | "H" | "HOURS" | "HR" | "HRS" => ExtractIntervalHours(source) - case "MINUTE" | "M" | "MIN" | "MINS" | "MINUTES" => ExtractIntervalMinutes(source) - case "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" => ExtractIntervalSeconds(source) - case _ => errorHandleFunc + errorHandleFunc: => Nothing): Expression = { + (extractField.toUpperCase(Locale.ROOT), source.dataType) match { + case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", YearMonthIntervalType) => + ExtractANSIIntervalYears(source) + case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", CalendarIntervalType) => + ExtractIntervalYears(source) + case ("MONTH" | "MON" | "MONS" | "MONTHS", YearMonthIntervalType) => + ExtractANSIIntervalMonths(source) + case ("MONTH" | "MON" | "MONS" | "MONTHS", CalendarIntervalType) => + ExtractIntervalMonths(source) + case ("DAY" | "D" | "DAYS", DayTimeIntervalType) => + ExtractANSIIntervalDays(source) + case ("DAY" | "D" | "DAYS", CalendarIntervalType) => + ExtractIntervalDays(source) + case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", DayTimeIntervalType) => + ExtractANSIIntervalHours(source) + case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", CalendarIntervalType) => + ExtractIntervalHours(source) + case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", DayTimeIntervalType) => + ExtractANSIIntervalMinutes(source) + case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", CalendarIntervalType) => + ExtractIntervalMinutes(source) + case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", DayTimeIntervalType) => + ExtractANSIIntervalSeconds(source) + case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType) => + ExtractIntervalSeconds(source) + case _ => errorHandleFunc + } } } @@ -391,11 +444,22 @@ case class MultiplyDTInterval( copy(interval = newLeft, num = newRight) } +trait IntervalDivide { + def checkDivideOverflow(value: Any, minValue: Any, num: Expression, numValue: Any): Unit = { + if (value == minValue && num.dataType.isInstanceOf[IntegralType]) { + if (numValue.asInstanceOf[Number].longValue() == -1) { + throw QueryExecutionErrors.overflowInIntegralDivideError() + } + } + } +} + // Divide an year-month interval by a numeric case class DivideYMInterval( interval: Expression, num: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with IntervalDivide + with NullIntolerant with Serializable { override def left: Expression = interval override def right: Expression = num @@ -418,20 +482,31 @@ case class DivideYMInterval( } override def nullSafeEval(interval: Any, num: Any): Any = { + checkDivideOverflow(interval.asInstanceOf[Int], Int.MinValue, right, num) evalFunc(interval.asInstanceOf[Int], num) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = right.dataType match { - case LongType => - val math = classOf[LongMath].getName + case t: IntegralType => + val math = t match { + case LongType => classOf[LongMath].getName + case _ => classOf[IntMath].getName + } val javaType = CodeGenerator.javaType(dataType) - defineCodeGen(ctx, ev, (m, 
n) => + val months = left.genCode(ctx) + val num = right.genCode(ctx) + val checkIntegralDivideOverflow = + s""" + |if (${months.value} == ${Int.MinValue} && ${num.value} == -1) + | throw QueryExecutionErrors.overflowInIntegralDivideError(); + |""".stripMargin + nullSafeCodeGen(ctx, ev, (m, n) => // Similarly to non-codegen code. The result of `divide(Int, Long, ...)` must fit to `Int`. // Casting to `Int` is safe here. - s"($javaType)($math.divide($m, $n, java.math.RoundingMode.HALF_UP))") - case _: IntegralType => - val math = classOf[IntMath].getName - defineCodeGen(ctx, ev, (m, n) => s"$math.divide($m, $n, java.math.RoundingMode.HALF_UP)") + s""" + |$checkIntegralDivideOverflow + |${ev.value} = ($javaType)$math.divide($m, $n, java.math.RoundingMode.HALF_UP); + """.stripMargin) case _: DecimalType => defineCodeGen(ctx, ev, (m, n) => s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" + @@ -454,7 +529,8 @@ case class DivideYMInterval( case class DivideDTInterval( interval: Expression, num: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with IntervalDivide + with NullIntolerant with Serializable { override def left: Expression = interval override def right: Expression = num @@ -473,13 +549,25 @@ case class DivideDTInterval( } override def nullSafeEval(interval: Any, num: Any): Any = { + checkDivideOverflow(interval.asInstanceOf[Long], Long.MinValue, right, num) evalFunc(interval.asInstanceOf[Long], num) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = right.dataType match { case _: IntegralType => val math = classOf[LongMath].getName - defineCodeGen(ctx, ev, (m, n) => s"$math.divide($m, $n, java.math.RoundingMode.HALF_UP)") + val micros = left.genCode(ctx) + val num = right.genCode(ctx) + val checkIntegralDivideOverflow = + s""" + |if (${micros.value} == ${Long.MinValue}L && ${num.value} == -1L) + | throw QueryExecutionErrors.overflowInIntegralDivideError(); + |""".stripMargin + nullSafeCodeGen(ctx, ev, (m, n) => + s""" + |$checkIntegralDivideOverflow + |${ev.value} = $math.divide($m, $n, java.math.RoundingMode.HALF_UP); + """.stripMargin) case _: DecimalType => defineCodeGen(ctx, ev, (m, n) => s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 2c2df6bf43..d4a02c7fc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -345,6 +346,8 @@ case class NaNvl(left: Expression, right: Expression) case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false + final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK) + override def eval(input: InternalRow): Any = { child.eval(input) == null } @@ -375,6 +378,8 @@ case 
class IsNull(child: Expression) extends UnaryExpression with Predicate { case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false + final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK) + override def eval(input: InternalRow): Any = { child.eval(input) != null } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5ae0cef7b4..a17ac203ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TernaryLike +import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -1705,6 +1706,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) override def foldable: Boolean = false override def nullable: Boolean = false + final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK) + override def flatArguments: Iterator[Any] = Iterator(child) private val errMsg = "Null value appeared in non-nullable field:" + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d78c726753..4885f7761f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -309,6 +309,8 @@ case class Not(child: Expression) override def inputTypes: Seq[DataType] = Seq(BooleanType) + final override val nodePatterns: Seq[TreePattern] = Seq(NOT) + // +---------+-----------+ // | CHILD | NOT CHILD | // +---------+-----------+ @@ -342,6 +344,7 @@ case class InSubquery(values: Seq[Expression], query: ListQuery) values.head } + final override val nodePatterns: Seq[TreePattern] = Seq(IN_SUBQUERY) override def checkInputDataTypes(): TypeCheckResult = { if (values.length != query.childOutputs.length) { @@ -434,7 +437,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) - override val nodePatterns: Seq[TreePattern] = Seq(IN) + final override val nodePatterns: Seq[TreePattern] = Seq(IN) override def toString: 
String = s"$value IN ${list.mkString("(", ",", ")")}" @@ -547,6 +550,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with override def nullable: Boolean = child.nullable || hasNull + final override val nodePatterns: Seq[TreePattern] = Seq(INSET) + protected override def nullSafeEval(value: Any): Any = { if (set.contains(value)) { true @@ -663,6 +668,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with override def sqlOperator: String = "AND" + final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR) + // +---------+---------+---------+---------+ // | AND | TRUE | FALSE | UNKNOWN | // +---------+---------+---------+---------+ @@ -749,6 +756,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P override def sqlOperator: String = "OR" + final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR) + // +---------+---------+---------+---------+ // | OR | TRUE | FALSE | UNKNOWN | // +---------+---------+---------+---------+ @@ -820,6 +829,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { // finitely enumerable. The allowable types are checked below by checkInputDataTypes. override def inputType: AbstractDataType = AnyDataType + final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_COMPARISON) + override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 13d00faea3..57d7d76268 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern} import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -129,6 +130,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY) + override def toString: String = escapeChar match { case '\\' => s"$left LIKE $right" case c => s"$left LIKE $right ESCAPE '$c'" @@ -198,6 +201,8 @@ sealed abstract class MultiLikeBase override def nullable: Boolean = true + final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY) + protected lazy val hasNull: Boolean = patterns.contains(null) protected lazy val cache = patterns.filterNot(_ == null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 3d5f812af9..5956c3e882 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala 
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -406,6 +407,8 @@ case class Upper(child: Expression) override def convert(v: UTF8String): UTF8String = v.toUpperCase // scalastyle:on caselocale + final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") } @@ -432,6 +435,8 @@ case class Lower(child: Expression) override def convert(v: UTF8String): UTF8String = v.toLowerCase // scalastyle:on caselocale + final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 2bedf84271..ac939bf6d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, LIST_SUBQUERY, + PLAN_EXPRESSION, SCALAR_SUBQUERY, TreePattern} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.BitSet @@ -38,6 +40,11 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression { bits } + final override val nodePatterns: Seq[TreePattern] = Seq(PLAN_EXPRESSION) ++ nodePatternsInternal + + // Subclasses can override this function to provide more TreePatterns. + def nodePatternsInternal(): Seq[TreePattern] = Seq() + /** The id of the subquery expression. 
*/ def exprId: ExprId @@ -247,6 +254,8 @@ case class ScalarSubquery( override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren) + + final override def nodePatternsInternal: Seq[TreePattern] = Seq(SCALAR_SUBQUERY) } object ScalarSubquery { @@ -295,6 +304,8 @@ case class ListQuery( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery = copy(children = newChildren) + + final override def nodePatternsInternal: Seq[TreePattern] = Seq(LIST_SUBQUERY) } /** @@ -340,4 +351,6 @@ case class Exists( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists = copy(children = newChildren) + + final override def nodePatternsInternal: Seq[TreePattern] = Seq(EXISTS_SUBQUERY) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 916f4eae7e..fe9c41e387 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -101,7 +101,10 @@ case class WindowSpecDefinition( private def isValidFrameType(ft: DataType): Boolean = (orderSpec.head.dataType, ft) match { case (DateType, IntegerType) => true + case (DateType, YearMonthIntervalType) => true case (TimestampType, CalendarIntervalType) => true + case (TimestampType, YearMonthIntervalType) => true + case (TimestampType, DayTimeIntervalType) => true case (a, b) => a == b } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index 8f1548a978..0ff11ca49f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule /** @@ -26,6 +26,15 @@ import org.apache.spark.sql.catalyst.rules.Rule */ object SimplifyExtractValueOps extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // One place where this optimization is invalid is an aggregation where the select + // list expression is a function of a grouping expression: + // + // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b) + // + // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this + // optimization for Aggregates (although this misses some cases where the optimization + // can be made). + case a: Aggregate => a case p => p.transformExpressionsUp { // Remove redundant field extraction. 
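The `Aggregate` carve-out described in the comment above is easiest to see from the query plans themselves. Below is a minimal, hypothetical spark-shell style sketch (the table and column names are made up) contrasting the case where the extraction can collapse with the GROUP BY case where it must not:

```scala
import org.apache.spark.sql.SparkSession

object ExtractValueAggregateDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    import spark.implicits._

    Seq((1, "x"), (2, "y")).toDF("a", "b").createOrReplaceTempView("tbl")

    // Outside an Aggregate, struct(a, b).a can safely collapse to just `a`.
    spark.sql("SELECT struct(a, b).a FROM tbl").explain(true)

    // Under GROUP BY struct(a, b) the same rewrite would produce a column that is
    // not a grouping expression, so the rule now leaves Aggregate nodes untouched.
    spark.sql("SELECT struct(a, b).a FROM tbl GROUP BY struct(a, b)").explain(true)

    spark.stop()
  }
}
```

In the second plan the `struct(a, b).a` projection stays intact, since rewriting it to a bare `a` would no longer match the grouping expression.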
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 0be2792bfd..5b12667f4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -231,6 +231,27 @@ object NestedColumnAliasing { * of it. */ object GeneratorNestedColumnAliasing { + // Partitions `attrToAliases` based on whether the attribute is in Generator's output. + private def aliasesOnGeneratorOutput( + attrToAliases: Map[ExprId, Seq[Alias]], + generatorOutput: Seq[Attribute]) = { + val generatorOutputExprId = generatorOutput.map(_.exprId) + attrToAliases.partition { k => + generatorOutputExprId.contains(k._1) + } + } + + // Partitions `nestedFieldToAlias` based on whether the attribute of nested field extractor + // is in Generator's output. + private def nestedFieldOnGeneratorOutput( + nestedFieldToAlias: Map[ExtractValue, Alias], + generatorOutput: Seq[Attribute]) = { + val generatorOutputSet = AttributeSet(generatorOutput) + nestedFieldToAlias.partition { pair => + pair._1.references.subsetOf(generatorOutputSet) + } + } + def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { // Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we // need to prune nested columns through Project and under Generate. The difference is @@ -241,12 +262,81 @@ object GeneratorNestedColumnAliasing { // On top on `Generate`, a `Project` that might have nested column accessors. // We try to get alias maps for both project list and generator's children expressions. val exprsToPrune = projectList ++ g.generator.children - NestedColumnAliasing.getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput).map { + NestedColumnAliasing.getAliasSubMap(exprsToPrune).map { case (nestedFieldToAlias, attrToAliases) => + val (nestedFieldsOnGenerator, nestedFieldsNotOnGenerator) = + nestedFieldOnGeneratorOutput(nestedFieldToAlias, g.qualifiedGeneratorOutput) + val (attrToAliasesOnGenerator, attrToAliasesNotOnGenerator) = + aliasesOnGeneratorOutput(attrToAliases, g.qualifiedGeneratorOutput) + + // Push nested column accessors through `Generator`. // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. - val newChild = - NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases) - Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild) + val newChild = NestedColumnAliasing.replaceWithAliases(g, + nestedFieldsNotOnGenerator, attrToAliasesNotOnGenerator) + val pushedThrough = Project(NestedColumnAliasing + .getNewProjectList(projectList, nestedFieldsNotOnGenerator), newChild) + + // If the generator output is `ArrayType`, we cannot push through the extractor. + // It is because we don't allow field extractor on two-level array, + // i.e., attr.field when attr is a ArrayType(ArrayType(...)). + // Similarily, we also cannot push through if the child of generator is `MapType`. + g.generator.children.head.dataType match { + case _: MapType => return Some(pushedThrough) + case ArrayType(_: ArrayType, _) => return Some(pushedThrough) + case _ => + } + + // Pruning on `Generator`'s output. We only process single field case. 
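The two partitioning helpers above only split the alias maps by key membership. A plain-Scala analogue, with made-up expression ids, shows which bucket feeds which rewrite:

```scala
// Not Catalyst code: a standalone sketch of splitting the alias map into aliases
// keyed by the generator's output ids versus everything else.
object AliasPartitionSketch {
  def main(args: Array[String]): Unit = {
    val attrToAliases: Map[Long, Seq[String]] =
      Map(1L -> Seq("_extract_a"), 7L -> Seq("_extract_f1", "_extract_f2"))
    val generatorOutputIds: Set[Long] = Set(1L) // hypothetical ids

    val (onGenerator, notOnGenerator) =
      attrToAliases.partition { case (exprId, _) => generatorOutputIds.contains(exprId) }

    println(s"kept for the single-field rewrite on the generator output: $onGenerator")
    println(s"pushed through the Generate via replaceWithAliases: $notOnGenerator")
  }
}
```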
+ // For multiple field case, we cannot directly move field extractor into + // the generator expression. A workaround is to re-construct array of struct + // from multiple fields. But it will be more complicated and may not worth. + // TODO(SPARK-34956): support multiple fields. + if (nestedFieldsOnGenerator.size > 1 || nestedFieldsOnGenerator.isEmpty) { + pushedThrough + } else { + // Only one nested column accessor. + // E.g., df.select(explode($"items").as("item")).select($"item.a") + pushedThrough match { + case p @ Project(_, newG: Generate) => + // Replace the child expression of `ExplodeBase` generator with + // nested column accessor. + // E.g., df.select(explode($"items").as("item")).select($"item.a") => + // df.select(explode($"items.a").as("item.a")) + val rewrittenG = newG.transformExpressions { + case e: ExplodeBase => + val extractor = nestedFieldsOnGenerator.head._1.transformUp { + case _: Attribute => + e.child + case g: GetStructField => + ExtractValue(g.child, Literal(g.extractFieldName), SQLConf.get.resolver) + } + e.withNewChildren(Seq(extractor)) + } + + // As we change the child of the generator, its output data type must be updated. + val updatedGeneratorOutput = rewrittenG.generatorOutput + .zip(rewrittenG.generator.elementSchema.toAttributes) + .map { case (oldAttr, newAttr) => + newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name) + } + assert(updatedGeneratorOutput.length == rewrittenG.generatorOutput.length, + "Updated generator output must have the same length " + + "with original generator output.") + val updatedGenerate = rewrittenG.copy(generatorOutput = updatedGeneratorOutput) + + // Replace nested column accessor with generator output. + p.withNewChildren(Seq(updatedGenerate)).transformExpressions { + case f: ExtractValue if nestedFieldsOnGenerator.contains(f) => + updatedGenerate.output + .find(a => attrToAliasesOnGenerator.contains(a.exprId)) + .getOrElse(f) + } + + case other => + // We should not reach here. 
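The single-field rewrite sketched in the comments above targets queries like the following. This is a hedged, self-contained example; `Item` and the column names are invented for illustration:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.explode

case class Item(a: Int, b: String)

object ExplodeFieldPruningDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    import spark.implicits._

    val df = Seq(Seq(Item(1, "x"), Item(2, "y"))).toDF("items")

    // Only `item.a` is referenced after the explode, so the rewrite described
    // above can turn explode(items) into explode(items.a); field `b` is never read.
    df.select(explode($"items").as("item"))
      .select($"item.a")
      .explain(true)

    spark.stop()
  }
}
```

With only `item.a` referenced, the generator's child can be narrowed before exploding instead of exploding full structs and extracting the field afterwards.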
+ throw new IllegalStateException(s"Unreasonable plan after optimization: $other") + } + } } case g: Generate if SQLConf.get.nestedSchemaPruningEnabled && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5343fce07c..16e3e43356 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -118,8 +119,7 @@ abstract class Optimizer(catalogManager: CatalogManager) OptimizeUpdateFields, SimplifyExtractValueOps, OptimizeCsvJsonExprs, - CombineConcats, - UpdateGroupingExprRefNullability) ++ + CombineConcats) ++ extendedOperatorOptimizationRules val operatorOptimizationBatch: Seq[Batch] = { @@ -148,7 +148,6 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateView, ReplaceExpressions, RewriteNonCorrelatedExists, - EnforceGroupingReferencesInAggregates, ComputeCurrentTime, GetCurrentDatabaseAndCatalog(catalogManager)) :: ////////////////////////////////////////////////////////////////////////////////////////// @@ -268,9 +267,7 @@ abstract class Optimizer(catalogManager: CatalogManager) RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: NormalizeFloatingNumbers.ruleName :: - ReplaceUpdateFieldsExpression.ruleName :: - EnforceGroupingReferencesInAggregates.ruleName :: - UpdateGroupingExprRefNullability.ruleName :: Nil + ReplaceUpdateFieldsExpression.ruleName :: Nil /** * Optimize all the subqueries inside expression. 
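The hunks that follow move rules onto the `transform*WithPruning` variants guarded by tree-pattern predicates such as `containsPattern(PLAN_EXPRESSION)`. As a rough, non-Catalyst model of that mechanism (the pattern id and node shapes are invented):

```scala
import scala.collection.immutable.BitSet

// Toy model of pattern-based pruning: each node caches the set of pattern bits
// occurring in itself and its descendants, and a rule is only attempted on
// subtrees whose bit set contains the pattern the rule declares interest in.
object PatternPruningSketch {
  val PLAN_EXPRESSION = 0 // hypothetical pattern id

  case class Node(name: String, ownPatterns: BitSet, children: Seq[Node]) {
    // Union of this node's patterns with its children's, mirroring the cached bits.
    lazy val patterns: BitSet = children.map(_.patterns).foldLeft(ownPatterns)(_ | _)

    def transformUpWithPruning(bit: Int)(rule: PartialFunction[Node, Node]): Node = {
      if (!patterns.contains(bit)) this // skip the whole subtree
      else {
        val newChildren = children.map(_.transformUpWithPruning(bit)(rule))
        rule.applyOrElse(copy(children = newChildren), (n: Node) => n)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val plan = Node("Filter", BitSet.empty, Seq(
      Node("ScalarSubquery", BitSet(PLAN_EXPRESSION), Nil),
      Node("Literal", BitSet.empty, Nil)))

    val optimized = plan.transformUpWithPruning(PLAN_EXPRESSION) {
      case n if n.name == "ScalarSubquery" => n.copy(name = "OptimizedSubquery")
    }
    println(optimized)
  }
}
```

The point is the early return: a subtree whose cached bits lack the pattern is never traversed by the rule at all.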
@@ -283,7 +280,8 @@ abstract class Optimizer(catalogManager: CatalogManager) case other => other } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsPattern(PLAN_EXPRESSION), ruleId) { case s: SubqueryExpression => val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s)) // At this point we have an optimized subquery plan that we are going to attach @@ -510,7 +508,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper { case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) => val aliasMap = getAliasMap(lower) - val newAggregate = Aggregate.withGroupingRefs( + val newAggregate = upper.copy( child = lower.child, groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)), aggregateExpressions = upper.aggregateExpressions.map( @@ -526,19 +524,23 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper { } private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = { - val upperHasNoAggregateExpressions = - !upper.aggregateExpressions.exists(AggregateExpression.containsAggregate) + val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate) lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet( lower .aggregateExpressions .filter(_.deterministic) - .filterNot(AggregateExpression.containsAggregate) + .filter(!isAggregate(_)) .map(_.toAttribute) )) upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg } + + private def isAggregate(expr: Expression): Boolean = { + expr.find(e => e.isInstanceOf[AggregateExpression] || + PythonUDF.isGroupedAggPandasUDF(e)).isDefined + } } /** @@ -1976,18 +1978,7 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { case a @ Aggregate(grouping, _, _) if grouping.nonEmpty => val newGrouping = grouping.filter(!_.foldable) if (newGrouping.nonEmpty) { - val droppedGroupsBefore = - grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray - - val newAggregateExpressions = - a.aggregateExpressions.map(_.transform { - case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 => - g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal)) - }.asInstanceOf[NamedExpression]) - - a.copy( - groupingExpressions = newGrouping, - aggregateExpressions = newAggregateExpressions) + a.copy(groupingExpressions = newGrouping) } else { // All grouping expressions are literals. We should not drop them all, because this can // change the return semantics when the input of the Aggregate is empty (SPARK-17114). 
We @@ -2008,25 +1999,7 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { if (newGrouping.size == grouping.size) { a } else { - var i = 0 - val droppedGroupsBefore = grouping.scanLeft(0)((n, e) => - n + (if (i >= newGrouping.size || e.eq(newGrouping(i))) { - i += 1 - 0 - } else { - 1 - }) - ).toArray - - val newAggregateExpressions = - a.aggregateExpressions.map(_.transform { - case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 => - g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal)) - }.asInstanceOf[NamedExpression]) - - a.copy( - groupingExpressions = newGrouping, - aggregateExpressions = newAggregateExpressions) + a.copy(groupingExpressions = newGrouping) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 3c48742b87..3de19afa91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, DeleteFromTable, Filter, InsertAction, InsertStarAction, Join, LogicalPlan, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_LITERAL, TRUE_OR_FALSE_LITERAL} +import org.apache.spark.sql.catalyst.trees.TreePattern.{INSET, NULL_LITERAL, TRUE_OR_FALSE_LITERAL} import org.apache.spark.sql.types.BooleanType import org.apache.spark.util.Utils @@ -51,7 +51,7 @@ import org.apache.spark.util.Utils object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( - _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL), ruleId) { + _.containsAnyPattern(NULL_LITERAL, TRUE_OR_FALSE_LITERAL, INSET), ruleId) { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) case j @ Join(_, _, _, Some(cond), _) => j.copy(condition = Some(replaceNullWithFalse(cond))) case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala index 8964a2776b..be39c3f10e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala @@ -49,28 +49,22 @@ object OptimizeUpdateFields extends Rule[LogicalPlan] { val values = withFields.map(_.valExpr) val newNames = mutable.ArrayBuffer.empty[String] - val newValues = mutable.ArrayBuffer.empty[Expression] + val newValues = mutable.HashMap.empty[String, Expression] + // Used to remember the casing of the last instance + val nameMap = mutable.HashMap.empty[String, String] - if (caseSensitive) { - names.zip(values).reverse.foreach { case (name, value) => - if (!newNames.contains(name)) { - newNames += name - newValues += value - } - } - } else { - val nameSet = mutable.HashSet.empty[String] 
- names.zip(values).reverse.foreach { case (name, value) => - val lowercaseName = name.toLowerCase(Locale.ROOT) - if (!nameSet.contains(lowercaseName)) { - newNames += name - newValues += value - nameSet += lowercaseName - } + names.zip(values).foreach { case (name, value) => + val normalizedName = if (caseSensitive) name else name.toLowerCase(Locale.ROOT) + if (nameMap.contains(normalizedName)) { + newValues += normalizedName -> value + } else { + newNames += normalizedName + newValues += normalizedName -> value } + nameMap += normalizedName -> name } - val newWithFields = newNames.reverse.zip(newValues.reverse).map(p => WithField(p._1, p._2)) + val newWithFields = newNames.map(n => WithField(nameMap(n), newValues(n))) UpdateFields(structExpr, newWithFields.toSeq) case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 65466c5ec0..e9752e046a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.trees.TreePattern.IN +import org.apache.spark.sql.catalyst.trees.AlwaysProcess +import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -50,8 +51,9 @@ object ConstantFolding extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressionsDown { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(AlwaysProcess.fn, ruleId) { + case q: LogicalPlan => q.transformExpressionsDownWithPruning( + AlwaysProcess.fn, ruleId) { // Skip redundant folding of literals. This rule is technically not necessary. Placing this // here avoids running the next rule for Literal values, which would create a new Literal // object and running eval unnecessarily. @@ -83,7 +85,8 @@ object ConstantFolding extends Rule[LogicalPlan] { * in the AND node. */ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( + _.containsAllPatterns(LITERAL, FILTER), ruleId) { case f: Filter => val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true) if (newCondition.isDefined) { @@ -210,14 +213,15 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { case _ => ExpressionSet(Seq.empty) } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsPattern(BINARY_ARITHMETIC), ruleId) { case q: LogicalPlan => // We have to respect aggregate expressions which exists in grouping expressions when plan // is an Aggregate operator, otherwise the optimized expression could not be derived from // grouping expressions. 
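A concrete query for the comment above, as a small sketch over a synthetic table:

```scala
import org.apache.spark.sql.SparkSession

object ReorderWithGroupingDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    spark.range(10).toDF("a").createOrReplaceTempView("t")

    // (a + 1) is a grouping expression, so the select expression a + 1 + 2 must
    // not be flattened into a + 3 -- it could then no longer be matched against
    // the grouping expressions of the Aggregate.
    spark.sql("SELECT a + 1 + 2 FROM t GROUP BY a + 1").explain(true)

    spark.stop()
  }
}
```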
// TODO: do not reorder consecutive `Add`s or `Multiply`s with different `failOnError` flags val groupingExpressionSet = collectGroupingExpressions(q) - q transformExpressionsDown { + q.transformExpressionsDownWithPruning(_.containsPattern(BINARY_ARITHMETIC)) { case a @ Add(_, _, f) if a.deterministic && a.dataType.isInstanceOf[IntegralType] => val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable) if (foldables.size > 1) { @@ -286,8 +290,10 @@ object OptimizeIn extends Rule[LogicalPlan] { * 4. Removes `Not` operator. */ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressionsUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsAnyPattern(AND_OR, NOT), ruleId) { + case q: LogicalPlan => q.transformExpressionsUpWithPruning( + _.containsAnyPattern(AND_OR, NOT), ruleId) { case TrueLiteral And e => e case e And TrueLiteral => e case FalseLiteral Or e => e @@ -460,7 +466,8 @@ object SimplifyBinaryComparison } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsPattern(BINARY_COMPARISON), ruleId) { case l: LogicalPlan => lazy val notNullExpressions = ExpressionSet(l match { case Filter(fc, _) => @@ -470,7 +477,7 @@ object SimplifyBinaryComparison case _ => Seq.empty }) - l transformExpressionsUp { + l.transformExpressionsUpWithPruning(_.containsPattern(BINARY_COMPARISON)) { // True with equality case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral case a EqualTo b if canSimplifyComparison(a, b, notNullExpressions) => TrueLiteral @@ -496,7 +503,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsAnyPattern(IF, CASE_WHEN), ruleId) { case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue @@ -584,7 +592,7 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { true case _: CastBase => true case _: GetDateField | _: LastDay => true - case _: ExtractIntervalPart => true + case _: ExtractIntervalPart[_] => true case _: ArraySetLike => true case _: ExtractValue => true case _ => false @@ -601,8 +609,10 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressionsUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsAnyPattern(CASE_WHEN, IF), ruleId) { + case q: LogicalPlan => q.transformExpressionsUpWithPruning( + _.containsAnyPattern(CASE_WHEN, IF), ruleId) { case u @ UnaryExpression(i @ If(_, trueValue, falseValue)) if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( @@ -713,7 +723,8 @@ object LikeSimplification extends Rule[LogicalPlan] { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsPattern(LIKE_FAMLIY), ruleId) { case l @ Like(input, Literal(pattern, StringType), escapeChar) => if (pattern == null) { // If pattern is null, return 
null value directly, since "col like null" == null. @@ -740,8 +751,12 @@ object NullPropagation extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressionsUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT) + || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) { + case q: LogicalPlan => q.transformExpressionsUpWithPruning( + t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT) + || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) { case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) => Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) case e @ AggregateExpression(Count(exprs), _, _, _, _) if exprs.forall(isNullLiteral) => @@ -917,7 +932,8 @@ object FoldablePropagation extends Rule[LogicalPlan] { * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type. */ object SimplifyCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsPattern(CAST), ruleId) { case Cast(e, dataType, _) if e.dataType == dataType => e case c @ Cast(e, dataType, _) => (e.dataType, dataType) match { case (ArrayType(from, false), ArrayType(to, true)) if from == to => e @@ -933,7 +949,8 @@ object SimplifyCasts extends Rule[LogicalPlan] { * Removes nodes that are not necessary. */ object RemoveDispensableExpressions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsPattern(UNARY_POSITIVE), ruleId) { case UnaryPositive(child) => child } } @@ -944,8 +961,10 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] { * the inner conversion is overwritten by the outer one. 
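For reference, the nested-conversion collapse this rule's scaladoc describes can be observed with a trivial DataFrame (illustrative only):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{lower, upper}

object CaseConversionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    import spark.implicits._

    val df = Seq("Spark").toDF("s")
    // upper(lower(s)) simplifies to upper(s): only the outermost conversion matters.
    df.select(upper(lower($"s"))).explain(true)

    spark.stop()
  }
}
```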
*/ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressionsUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsPattern(UPPER_OR_LOWER), ruleId) { + case q: LogicalPlan => q.transformExpressionsUpWithPruning( + _.containsPattern(UPPER_OR_LOWER), ruleId) { case Upper(Upper(child)) => Upper(child) case Upper(Lower(child)) => Upper(child) case Lower(Upper(child)) => Lower(child) @@ -986,7 +1005,8 @@ object CombineConcats extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning( + _.containsPattern(CONCAT), ruleId) { case concat: Concat if hasNestedConcats(concat) => flattenConcats(concat) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 9381796d3d..ca3aca54f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, FILTER, IN_SUBQUERY, + LIST_SUBQUERY, SCALAR_SUBQUERY} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -94,7 +96,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { } } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + t => t.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY) && t.containsPattern(FILTER)) { case Filter(condition, child) if SubqueryExpression.hasInOrCorrelatedExistsSubquery(condition) => val (withSubquery, withoutSubquery) = @@ -164,7 +167,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { plan: LogicalPlan): (Option[Expression], LogicalPlan) = { var newPlan = plan val newExprs = exprs.map { e => - e transformDown { + e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) { case Exists(sub, conditions, _) => val exists = AttributeReference("exists", BooleanType, nullable = false)() newPlan = @@ -303,7 +306,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper } } - plan transformExpressions { + plan.transformExpressionsWithPruning(_.containsAnyPattern( + SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) { case ScalarSubquery(sub, children, exprId) if children.nonEmpty => val (newPlan, newCond) = decorrelate(sub, outerPlans) ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId) @@ -319,7 +323,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper /** * Pull up the correlated predicates and rewrite all subqueries in an operator tree.. 
*/ - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( + _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) { case f @ Filter(_, a: Aggregate) => rewriteSubQueries(f, Seq(a, a.child)) // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. @@ -341,7 +346,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe private def extractCorrelatedScalarSubqueries[E <: Expression]( expression: E, subqueries: ArrayBuffer[ScalarSubquery]): E = { - val newExpression = expression transform { + val newExpression = expression.transformWithPruning(_.containsPattern(SCALAR_SUBQUERY)) { case s: ScalarSubquery if s.children.nonEmpty => subqueries += s s.plan.output.head @@ -628,10 +633,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe * subqueries. */ def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput { - case a @ Aggregate(grouping, _, child) => + case a @ Aggregate(grouping, expressions, child) => val subqueries = ArrayBuffer.empty[ScalarSubquery] - val rewriteExprs = a.aggregateExpressionsWithoutGroupingRefs - .map(extractCorrelatedScalarSubqueries(_, subqueries)) + val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) if (subqueries.nonEmpty) { // We currently only allow correlated subqueries in an aggregate if they are part of the // grouping expressions. As a result we need to replace all the scalar subqueries in the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8c8e8c6a8a..06bbb984d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -993,26 +993,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg .map(groupByExpr => { val groupingAnalytics = groupByExpr.groupingAnalytics if (groupingAnalytics != null) { - val groupingSets = groupingAnalytics.groupingSet.asScala - .map(_.expression.asScala.map(e => expression(e)).toSeq) - if (groupingAnalytics.CUBE != null) { - // CUBE(A, B, (A, B), ()) is not supported. - if (groupingSets.exists(_.isEmpty)) { - throw new ParseException("Empty set in CUBE grouping sets is not supported.", - groupingAnalytics) - } - Cube(groupingSets.toSeq) - } else if (groupingAnalytics.ROLLUP != null) { - // ROLLUP(A, B, (A, B), ()) is not supported. - if (groupingSets.exists(_.isEmpty)) { - throw new ParseException("Empty set in ROLLUP grouping sets is not supported.", - groupingAnalytics) - } - Rollup(groupingSets.toSeq) - } else { - assert(groupingAnalytics.GROUPING != null && groupingAnalytics.SETS != null) - GroupingSets(groupingSets.toSeq) - } + visitGroupingAnalytics(groupingAnalytics) } else { expression(groupByExpr.expression) } @@ -1021,6 +1002,36 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } } + override def visitGroupingAnalytics( + groupingAnalytics: GroupingAnalyticsContext): BaseGroupingSets = { + val groupingSets = groupingAnalytics.groupingSet.asScala + .map(_.expression.asScala.map(e => expression(e)).toSeq) + if (groupingAnalytics.CUBE != null) { + // CUBE(A, B, (A, B), ()) is not supported. 
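As context for the grouping-analytics handling here, a short sketch of what `CUBE` expands to and what the empty-set check below rejects (synthetic table, illustrative column names). The refactored `visitGroupingAnalytics` also flattens analytics clauses nested under `GROUPING SETS`, as its final branch suggests:

```scala
import org.apache.spark.sql.SparkSession

object GroupingAnalyticsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    spark.range(12).selectExpr("id % 2 AS a", "id % 3 AS b").createOrReplaceTempView("t")

    // CUBE(a, b) expands to the grouping sets (a, b), (a), (b), ().
    spark.sql("SELECT a, b, count(*) FROM t GROUP BY CUBE(a, b)").show()

    // An explicit empty set inside CUBE/ROLLUP, e.g. GROUP BY CUBE(a, b, ()),
    // is still rejected at parse time via invalidGroupingSetError.
    spark.stop()
  }
}
```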
+ if (groupingSets.exists(_.isEmpty)) { + throw QueryParsingErrors.invalidGroupingSetError("CUBE", groupingAnalytics) + } + Cube(groupingSets.toSeq) + } else if (groupingAnalytics.ROLLUP != null) { + // ROLLUP(A, B, (A, B), ()) is not supported. + if (groupingSets.exists(_.isEmpty)) { + throw QueryParsingErrors.invalidGroupingSetError("ROLLUP", groupingAnalytics) + } + Rollup(groupingSets.toSeq) + } else { + assert(groupingAnalytics.GROUPING != null && groupingAnalytics.SETS != null) + val groupingSets = groupingAnalytics.groupingElement.asScala.flatMap { expr => + val groupingAnalytics = expr.groupingAnalytics() + if (groupingAnalytics != null) { + visitGroupingAnalytics(groupingAnalytics).selectedGroupByExprs + } else { + Seq(expr.groupingSet().expression().asScala.map(e => expression(e)).toSeq) + } + } + GroupingSets(groupingSets.toSeq) + } + } + /** * Add [[UnresolvedHint]]s to a logical plan. */ @@ -2395,13 +2406,22 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg */ override def visitUnitToUnitInterval(ctx: UnitToUnitIntervalContext): CalendarInterval = { withOrigin(ctx) { - val value = Option(ctx.intervalValue.STRING).map(string).getOrElse { + val value = Option(ctx.intervalValue.STRING).map(string).map { interval => + if (ctx.intervalValue().MINUS() == null) { + interval + } else { + interval.startsWith("-") match { + case true => interval.replaceFirst("-", "") + case false => s"-$interval" + } + } + }.getOrElse { throw QueryParsingErrors.invalidFromToUnitValueError(ctx.intervalValue) } try { val from = ctx.from.getText.toLowerCase(Locale.ROOT) val to = ctx.to.getText.toLowerCase(Locale.ROOT) - val interval = (from, to) match { + (from, to) match { case ("year", "month") => IntervalUtils.fromYearMonthString(value) case ("day", "hour") => @@ -2419,9 +2439,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg case _ => throw QueryParsingErrors.fromToIntervalUnsupportedError(from, to, ctx) } - Option(ctx.intervalValue.MINUS) - .map(_ => IntervalUtils.negateExact(interval)) - .getOrElse(interval) } catch { // Handle Exceptions thrown by CalendarInterval case e: IllegalArgumentException => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index a96674fe97..c22a874779 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -287,7 +287,7 @@ object PhysicalAggregation { (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) def unapply(a: Any): Option[ReturnType] = a match { - case a @ logical.Aggregate(groupingExpressions, resultExpressions, child) => + case logical.Aggregate(groupingExpressions, resultExpressions, child) => // A single aggregate expression might appear multiple times in resultExpressions. // In order to avoid evaluating an individual aggregate function multiple times, we'll // build a set of semantically distinct aggregate expressions and re-write expressions so @@ -297,9 +297,11 @@ object PhysicalAggregation { val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { // addExpr() always returns false for non-deterministic expressions and do not add them. 
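The de-duplication described in the comment above shows up in the physical plan when the same aggregate appears twice in the result expressions. A minimal sketch with made-up data:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum

object AggregateDedupDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 10), (1, 20), (2, 30)).toDF("k", "v")

    // sum(v) appears twice in the result expressions but is planned as a single
    // aggregate function; the second occurrence reuses the first one's result.
    df.groupBy($"k").agg(sum($"v").as("s"), (sum($"v") * 2).as("s2")).explain(true)

    spark.stop()
  }
}
```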
- case a - if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) => - a + case agg: AggregateExpression + if !equivalentAggregateExpressions.addExpr(agg) => agg + case udf: PythonUDF + if PythonUDF.isGroupedAggPandasUDF(udf) && + !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -320,7 +322,7 @@ object PhysicalAggregation { // which takes the grouping columns and final aggregate result buffer as input. // Thus, we must re-write the result expressions so that their attributes match up with // the attributes of the final result projection's input row: - val rewrittenResultExpressions = a.aggregateExpressionsWithoutGroupingRefs.map { expr => + val rewrittenResultExpressions = resultExpressions.map { expr => expr.transformDown { case ae: AggregateExpression => // The final aggregation buffer's attributes will be `finalAggregationAttributes`, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 49e3e3c9ee..0f5bc7e1f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.plans.logical -import scala.collection.mutable - import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, MultiInstanceRelation, TypeCoercion, TypeCoercionBase} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} @@ -28,9 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.catalyst.trees.TreePattern.{ - INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern -} +import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -166,6 +162,8 @@ case class Filter(condition: Expression, child: LogicalPlan) override def maxRows: Option[Long] = child.maxRows + final override val nodePatterns: Seq[TreePattern] = Seq(FILTER) + override protected lazy val validConstraints: ExpressionSet = { val predicates = splitConjunctivePredicates(condition) .filterNot(SubqueryExpression.hasCorrelatedSubquery) @@ -781,23 +779,14 @@ case class Range( /** * This is a Group by operator with the aggregate functions and projections. * - * @param groupingExpressions Expressions for grouping keys. - * @param aggregateExpressions Expressions for a project list, which can contain - * [[AggregateExpression]]s and [[GroupingExprRef]]s. - * @param child The child of the aggregate node. + * @param groupingExpressions expressions for grouping keys + * @param aggregateExpressions expressions for a project list, which could contain + * [[AggregateExpression]]s. * - * Expressions without aggregate functions in [[aggregateExpressions]] can contain - * [[GroupingExprRef]]s to refer to complex grouping expressions in [[groupingExpressions]]. 
These - * references ensure that optimization rules don't change the aggregate expressions to invalid ones - * that no longer refer to any grouping expressions and also simplify the expression transformations - * on the node (need to transform the expression only once). - * - * For example, in the following query Spark shouldn't optimize the aggregate expression - * `Not(IsNull(c))` to `IsNotNull(c)` as the grouping expression is `IsNull(c)`: - * SELECT not(c IS NULL) - * FROM t - * GROUP BY c IS NULL - * Instead, the aggregate expression should contain `Not(GroupingExprRef(0))`. + * Note: Currently, aggregateExpressions is the project list of this Group by operator. Before + * separating projection from grouping and aggregate, we should avoid expression-level optimization + * on aggregateExpressions, which could reference an expression in groupingExpressions. + * For example, see the rule [[org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps]] */ case class Aggregate( groupingExpressions: Seq[Expression], @@ -824,21 +813,8 @@ case class Aggregate( } } - private def expandGroupingReferences(e: Expression): Expression = { - e match { - case _ if AggregateExpression.isAggregate(e) => e - case g: GroupingExprRef => groupingExpressions(g.ordinal) - case _ => e.mapChildren(expandGroupingReferences) - } - } - - lazy val aggregateExpressionsWithoutGroupingRefs = { - aggregateExpressions.map(expandGroupingReferences(_).asInstanceOf[NamedExpression]) - } - override lazy val validConstraints: ExpressionSet = { - val nonAgg = aggregateExpressionsWithoutGroupingRefs. - filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) + val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) getAllValidConstraints(nonAgg) } @@ -846,51 +822,6 @@ case class Aggregate( copy(child = newChild) } -object Aggregate { - private def collectComplexGroupingExpressions(groupingExpressions: Seq[Expression]) = { - val complexGroupingExpressions = mutable.Map.empty[Expression, (Expression, Int)] - var i = 0 - groupingExpressions.foreach { ge => - if (!ge.foldable && ge.children.nonEmpty && - !complexGroupingExpressions.contains(ge.canonicalized)) { - complexGroupingExpressions += ge.canonicalized -> (ge, i) - } - i += 1 - } - complexGroupingExpressions - } - - private def insertGroupingReferences( - aggregateExpressions: Seq[NamedExpression], - groupingExpressions: collection.Map[Expression, (Expression, Int)]): Seq[NamedExpression] = { - def insertGroupingExprRefs(e: Expression): Expression = { - e match { - case _ if AggregateExpression.isAggregate(e) => e - case _ if groupingExpressions.contains(e.canonicalized) => - val (groupingExpression, ordinal) = groupingExpressions(e.canonicalized) - GroupingExprRef(ordinal, groupingExpression.dataType, groupingExpression.nullable) - case _ => e.mapChildren(insertGroupingExprRefs) - } - } - - aggregateExpressions.map(insertGroupingExprRefs(_).asInstanceOf[NamedExpression]) - } - - def withGroupingRefs( - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: LogicalPlan): Aggregate = { - val complexGroupingExpressions = collectComplexGroupingExpressions(groupingExpressions) - val aggrExprWithGroupingReferences = if (complexGroupingExpressions.nonEmpty) { - insertGroupingReferences(aggregateExpressions, complexGroupingExpressions) - } else { - aggregateExpressions - } - - new Aggregate(groupingExpressions, aggrExprWithGroupingReferences, child) - } -} - case class Window( windowExpressions: 
Seq[NamedExpression], partitionSpec: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index d745f50604..1c997d3740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -42,17 +42,33 @@ object RuleIdCollection { // Catalyst Analyzer rules "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNaturalAndUsingJoin" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRandomSeed" :: + "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" :: + "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" :: // Catalyst Optimizer rules + "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" :: + "org.apache.spark.sql.catalyst.optimizer.CombineConcats" :: + "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" :: + "org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" :: "org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" :: "org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" :: + "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" :: + "org.apache.spark.sql.catalyst.optimizer.NullPropagation" :: "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" :: + "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" :: "org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" :: "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" :: + "org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" :: "org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" :: + "org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions" :: + "org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator" :: "org.apache.spark.sql.catalyst.optimizer.ReorderJoin" :: - "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" :: Nil + "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" :: + "org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" :: + "org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" :: + "org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" :: + "org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" :: Nil } // Maps rule names to ids. Rule ids are continuous natural numbers starting from 0. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index bb09b9ddda..faf736d9c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -23,15 +23,36 @@ object TreePattern extends Enumeration { // Enum Ids start from 0. 
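My reading of the rule-id registration above, offered as a rough model rather than the Catalyst implementation: the id lets a subtree remember that a registered rule has already been a no-op on it, complementing the pattern bits enumerated below.

```scala
import scala.collection.mutable

// Sketch only: a node records the ids of rules that did not change it, so those
// rules can be skipped on that subtree in later passes.
object RuleIdPruningSketch {
  case class Node(name: String, ineffectiveRules: mutable.Set[Int] = mutable.Set.empty[Int])

  def applyOnce(node: Node, ruleId: Int)(rule: Node => Node): Node = {
    if (node.ineffectiveRules.contains(ruleId)) {
      node // this rule already proved to be a no-op here
    } else {
      val rewritten = rule(node)
      if (rewritten eq node) node.ineffectiveRules += ruleId // remember the no-op
      rewritten
    }
  }

  def main(args: Array[String]): Unit = {
    val n = Node("Upper('x)")
    applyOnce(n, ruleId = 7)(identity)                // no-op: id 7 gets recorded
    println(n.ineffectiveRules)                       // Set(7)
    println(applyOnce(n, ruleId = 7)(identity) eq n)  // true: skipped immediately
  }
}
```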
// Expression patterns (alphabetically ordered) - val ATTRIBUTE_REFERENCE = Value(0) - val EXPRESSION_WITH_RANDOM_SEED = Value + val AND_OR: Value = Value(0) + val ATTRIBUTE_REFERENCE: Value = Value + val BINARY_ARITHMETIC: Value = Value + val BINARY_COMPARISON: Value = Value + val CASE_WHEN: Value = Value + val CAST: Value = Value + val CONCAT: Value = Value + val COUNT: Value = Value + val DYNAMIC_PRUNING_SUBQUERY: Value = Value + val EXISTS_SUBQUERY = Value + val EXPRESSION_WITH_RANDOM_SEED: Value = Value + val IF: Value = Value val IN: Value = Value + val IN_SUBQUERY: Value = Value + val INSET: Value = Value + val LIKE_FAMLIY: Value = Value + val LIST_SUBQUERY: Value = Value val LITERAL: Value = Value + val NOT: Value = Value + val NULL_CHECK: Value = Value val NULL_LITERAL: Value = Value + val PLAN_EXPRESSION: Value = Value + val SCALAR_SUBQUERY: Value = Value val TRUE_OR_FALSE_LITERAL: Value = Value val WINDOW_EXPRESSION: Value = Value + val UNARY_POSITIVE: Value = Value + val UPPER_OR_LOWER: Value = Value // Logical plan patterns (alphabetically ordered) + val FILTER: Value = Value val INNER_LIKE_JOIN: Value = Value val JOIN: Value = Value val LEFT_SEMI_OR_ANTI_JOIN: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index e52d3c8817..c9bc579ceb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -54,31 +54,39 @@ object IntervalUtils { } import IntervalUnit._ - def getYears(interval: CalendarInterval): Int = { - interval.months / MONTHS_PER_YEAR - } + def getYears(months: Int): Int = months / MONTHS_PER_YEAR - def getMonths(interval: CalendarInterval): Byte = { - (interval.months % MONTHS_PER_YEAR).toByte - } + def getYears(interval: CalendarInterval): Int = getYears(interval.months) + + def getMonths(months: Int): Byte = (months % MONTHS_PER_YEAR).toByte + + def getMonths(interval: CalendarInterval): Byte = getMonths(interval.months) + + def getDays(microseconds: Long): Int = (microseconds / MICROS_PER_DAY).toInt def getDays(interval: CalendarInterval): Int = { - val daysInMicroseconds = (interval.microseconds / MICROS_PER_DAY).toInt + val daysInMicroseconds = getDays(interval.microseconds) Math.addExact(interval.days, daysInMicroseconds) } - def getHours(interval: CalendarInterval): Long = { - (interval.microseconds % MICROS_PER_DAY) / MICROS_PER_HOUR + def getHours(microseconds: Long): Byte = { + ((microseconds % MICROS_PER_DAY) / MICROS_PER_HOUR).toByte } - def getMinutes(interval: CalendarInterval): Byte = { - ((interval.microseconds % MICROS_PER_HOUR) / MICROS_PER_MINUTE).toByte + def getHours(interval: CalendarInterval): Byte = getHours(interval.microseconds) + + def getMinutes(microseconds: Long): Byte = { + ((microseconds % MICROS_PER_HOUR) / MICROS_PER_MINUTE).toByte } - def getSeconds(interval: CalendarInterval): Decimal = { - Decimal(interval.microseconds % MICROS_PER_MINUTE, 8, 6) + def getMinutes(interval: CalendarInterval): Byte = getMinutes(interval.microseconds) + + def getSeconds(microseconds: Long): Decimal = { + Decimal(microseconds % MICROS_PER_MINUTE, 8, 6) } + def getSeconds(interval: CalendarInterval): Decimal = getSeconds(interval.microseconds) + private def toLongWithRange( fieldName: IntervalUnit, s: String, @@ -100,12 +108,11 @@ object IntervalUtils { */ def 
fromYearMonthString(input: String): CalendarInterval = { require(input != null, "Interval year-month string must be not null") - def toInterval(yearStr: String, monthStr: String): CalendarInterval = { + def toInterval(yearStr: String, monthStr: String, sign: Int): CalendarInterval = { try { - val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE).toInt - val months = toLongWithRange(MONTH, monthStr, 0, 11).toInt - val totalMonths = Math.addExact(Math.multiplyExact(years, 12), months) - new CalendarInterval(totalMonths, 0, 0) + val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR) + val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(MONTH, monthStr, 0, 11)) + new CalendarInterval(Math.toIntExact(totalMonths), 0, 0) } catch { case NonFatal(e) => throw new IllegalArgumentException( @@ -114,9 +121,9 @@ object IntervalUtils { } input.trim match { case yearMonthPattern("-", yearStr, monthStr) => - negateExact(toInterval(yearStr, monthStr)) + toInterval(yearStr, monthStr, -1) case yearMonthPattern(_, yearStr, monthStr) => - toInterval(yearStr, monthStr) + toInterval(yearStr, monthStr, 1) case _ => throw new IllegalArgumentException( s"Interval string does not match year-month format of 'y-m': $input") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index a3fbe4c742..446486bdf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1352,4 +1352,9 @@ private[spark] object QueryCompilationErrors { s"Expected udfs have the same evalType but got different evalTypes: " + s"${evalTypes.mkString(",")}") } + + def ambiguousFieldNameError(fieldName: String, names: String): Throwable = { + new AnalysisException( + s"Ambiguous field name: $fieldName. 
Found multiple columns that can match: $names") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index eb7b7b4ff6..3589c875fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -309,7 +309,7 @@ object QueryExecutionErrors { new IllegalStateException("table stats must be specified.") } - def unaryMinusCauseOverflowError(originValue: Short): ArithmeticException = { + def unaryMinusCauseOverflowError(originValue: AnyVal): ArithmeticException = { new ArithmeticException(s"- $originValue caused overflow.") } @@ -772,4 +772,55 @@ object QueryExecutionErrors { new IllegalArgumentException(s"Unexpected: $o") } + def unscaledValueTooLargeForPrecisionError(): Throwable = { + new ArithmeticException("Unscaled value too large for precision") + } + + def decimalPrecisionExceedsMaxPrecisionError(precision: Int, maxPrecision: Int): Throwable = { + new ArithmeticException( + s"Decimal precision $precision exceeds max precision $maxPrecision") + } + + def outOfDecimalTypeRangeError(str: UTF8String): Throwable = { + new ArithmeticException(s"out of decimal type range: $str") + } + + def unsupportedArrayTypeError(clazz: Class[_]): Throwable = { + new RuntimeException(s"Do not support array of type $clazz.") + } + + def unsupportedJavaTypeError(clazz: Class[_]): Throwable = { + new RuntimeException(s"Do not support type $clazz.") + } + + def failedParsingStructTypeError(raw: String): Throwable = { + new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw") + } + + def failedMergingFieldsError(leftName: String, rightName: String, e: Throwable): Throwable = { + new SparkException(s"Failed to merge fields '$leftName' and '$rightName'. 
${e.getMessage}") + } + + def cannotMergeDecimalTypesWithIncompatiblePrecisionAndScaleError( + leftPrecision: Int, rightPrecision: Int, leftScale: Int, rightScale: Int): Throwable = { + new SparkException("Failed to merge decimal types with incompatible " + + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") + } + + def cannotMergeDecimalTypesWithIncompatiblePrecisionError( + leftPrecision: Int, rightPrecision: Int): Throwable = { + new SparkException("Failed to merge decimal types with incompatible " + + s"precision $leftPrecision and $rightPrecision") + } + + def cannotMergeDecimalTypesWithIncompatibleScaleError( + leftScale: Int, rightScale: Int): Throwable = { + new SparkException("Failed to merge decimal types with incompatible " + + s"scala $leftScale and $rightScale") + } + + def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = { + new SparkException(s"Failed to merge incompatible data types ${left.catalogString}" + + s" and ${right.catalogString}") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index d97b19954f..b714f57875 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -367,4 +367,7 @@ object QueryParsingErrors { new ParseException("LOCAL is supported only with file: scheme", ctx) } + def invalidGroupingSetError(element: String, ctx: GroupingAnalyticsContext): Throwable = { + new ParseException(s"Empty set in $element grouping sets is not supported.", ctx) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 04e740039f..9d09715d25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3150,6 +3150,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val MAX_CONCURRENT_OUTPUT_FILE_WRITERS = buildConf("spark.sql.maxConcurrentOutputFileWriters") + .internal() + .doc("Maximum number of output file writers to use concurrently. If number of writers " + + "needed reaches this limit, task will sort rest of output then writing them.") + .version("3.2.0") + .intConf + .createWithDefault(0) + /** * Holds information about keys that have been deprecated. * @@ -3839,6 +3847,8 @@ class SQLConf extends Serializable with Logging { def decorrelateInnerQueryEnabled: Boolean = getConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED) + def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 960e174f9c..d9f457f153 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -22,6 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal, BigInteger, MathContext, Roundin import scala.util.Try import org.apache.spark.annotation.Unstable +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.unsafe.types.UTF8String @@ -80,7 +81,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ def set(unscaled: Long, precision: Int, scale: Int): Decimal = { if (setOrNull(unscaled, precision, scale) == null) { - throw new ArithmeticException("Unscaled value too large for precision") + throw QueryExecutionErrors.unscaledValueTooLargeForPrecisionError() } this } @@ -118,8 +119,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { DecimalType.checkNegativeScale(scale) this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP) if (decimalVal.precision > precision) { - throw new ArithmeticException( - s"Decimal precision ${decimalVal.precision} exceeds max precision $precision") + throw QueryExecutionErrors.decimalPrecisionExceedsMaxPrecisionError( + decimalVal.precision, precision) } this.longVal = 0L this._precision = precision @@ -251,7 +252,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toByte: Byte = toLong.toByte private def overflowException(dataType: String) = - throw new ArithmeticException(s"Casting $this to $dataType causes overflow") + throw QueryExecutionErrors.castingCauseOverflowError(this, dataType) /** * @return the Byte value that is equal to the rounded decimal. 
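Editorial aside (not part of the patch): the Decimal.scala hunks above replace inline `new ArithmeticException(...)` throws with calls into `QueryExecutionErrors`. A minimal sketch of that error-centralization pattern, using hypothetical `Demo*` names rather than Spark's; the message text mirrors the strings removed in the diff:

```scala
// Call sites stop building exception messages inline and go through one factory
// object, so the wording lives in a single, auditable place.
object DemoExecutionErrors {
  def castingCauseOverflowError(value: Any, dataType: String): ArithmeticException =
    new ArithmeticException(s"Casting $value to $dataType causes overflow")

  def unscaledValueTooLargeForPrecisionError(): ArithmeticException =
    new ArithmeticException("Unscaled value too large for precision")
}

object DemoCasts {
  // Before: `throw new ArithmeticException(s"Casting $x to byte causes overflow")`
  // After: delegate message construction to the factory.
  def toByteExact(x: Long): Byte =
    if (x == x.toByte) x.toByte
    else throw DemoExecutionErrors.castingCauseOverflowError(x, "byte")
}

object DemoCastsCheck extends App {
  assert(DemoCasts.toByteExact(42L) == 42.toByte)
  try { DemoCasts.toByteExact(1000L); assert(false) }
  catch { case e: ArithmeticException => assert(e.getMessage.contains("overflow")) }
}
```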
@@ -263,14 +264,14 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (actualLongVal == actualLongVal.toByte) { actualLongVal.toByte } else { - overflowException("byte") + throw QueryExecutionErrors.castingCauseOverflowError(this, "byte") } } else { val doubleVal = decimalVal.toDouble if (Math.floor(doubleVal) <= Byte.MaxValue && Math.ceil(doubleVal) >= Byte.MinValue) { doubleVal.toByte } else { - overflowException("byte") + throw QueryExecutionErrors.castingCauseOverflowError(this, "byte") } } } @@ -285,14 +286,14 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (actualLongVal == actualLongVal.toShort) { actualLongVal.toShort } else { - overflowException("short") + throw QueryExecutionErrors.castingCauseOverflowError(this, "short") } } else { val doubleVal = decimalVal.toDouble if (Math.floor(doubleVal) <= Short.MaxValue && Math.ceil(doubleVal) >= Short.MinValue) { doubleVal.toShort } else { - overflowException("short") + throw QueryExecutionErrors.castingCauseOverflowError(this, "short") } } } @@ -307,14 +308,14 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (actualLongVal == actualLongVal.toInt) { actualLongVal.toInt } else { - overflowException("int") + throw QueryExecutionErrors.castingCauseOverflowError(this, "int") } } else { val doubleVal = decimalVal.toDouble if (Math.floor(doubleVal) <= Int.MaxValue && Math.ceil(doubleVal) >= Int.MinValue) { doubleVal.toInt } else { - overflowException("int") + throw QueryExecutionErrors.castingCauseOverflowError(this, "int") } } } @@ -333,7 +334,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { // `longValueExact` to make sure the range check is accurate. decimalVal.bigDecimal.toBigInteger.longValueExact() } catch { - case _: ArithmeticException => overflowException("long") + case _: ArithmeticException => + throw QueryExecutionErrors.castingCauseOverflowError(this, "long") } } } @@ -365,8 +367,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (nullOnOverflow) { null } else { - throw new ArithmeticException( - s"$toDebugString cannot be represented as Decimal($precision, $scale).") + throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(this, precision, scale) } } } @@ -622,13 +623,13 @@ object Decimal { // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. 
// For example: Decimal("6.0790316E+25569151") if (calculatePrecision(bigDecimal) > DecimalType.MAX_PRECISION) { - throw new ArithmeticException(s"out of decimal type range: $str") + throw QueryExecutionErrors.outOfDecimalTypeRangeError(str) } else { Decimal(bigDecimal) } } catch { case _: NumberFormatException => - throw new NumberFormatException(s"invalid input syntax for type numeric: $str") + throw QueryExecutionErrors.invalidInputSyntaxForNumericError(str) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index bedf6ccf44..3e05eda344 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -23,6 +23,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.Stable +import org.apache.spark.sql.errors.QueryExecutionErrors /** @@ -162,13 +163,13 @@ object Metadata { builder.putMetadataArray( key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray) case other => - throw new RuntimeException(s"Do not support array of type ${other.getClass}.") + throw QueryExecutionErrors.unsupportedArrayTypeError(other.getClass) } } case (key, JNull) => builder.putNull(key) case (key, other) => - throw new RuntimeException(s"Do not support type ${other.getClass}.") + throw QueryExecutionErrors.unsupportedJavaTypeError(other.getClass) } builder.build() } @@ -195,7 +196,7 @@ object Metadata { case x: Metadata => toJsonValue(x.map) case other => - throw new RuntimeException(s"Do not support type ${other.getClass}.") + throw QueryExecutionErrors.unsupportedJavaTypeError(other.getClass) } } @@ -222,7 +223,7 @@ object Metadata { case null => 0 case other => - throw new RuntimeException(s"Do not support type ${other.getClass}.") + throw QueryExecutionErrors.unsupportedJavaTypeError(other.getClass) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index a223344e92..8ff0536c2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -23,14 +23,13 @@ import scala.util.control.NonFatal import org.json4s.JsonDSL._ -import org.apache.spark.SparkException import org.apache.spark.annotation.Stable -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.util.{truncatedString, StringUtils} import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf /** @@ -333,9 +332,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru if (found.length > 1) { val names = found.map(f => prettyFieldName(normalizedPath :+ f.name)) .mkString("[", ", ", " ]") - throw new AnalysisException( - s"Ambiguous field name: ${prettyFieldName(normalizedPath :+ searchName)}. 
Found " + - s"multiple columns that can match: $names") + throw QueryCompilationErrors.ambiguousFieldNameError( + prettyFieldName(normalizedPath :+ searchName), names) } else if (found.isEmpty) { None } else { @@ -523,7 +521,7 @@ object StructType extends AbstractDataType { private[sql] def fromString(raw: String): StructType = { Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parseString(raw)) match { case t: StructType => t - case _ => throw new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw") + case _ => throw QueryExecutionErrors.failedParsingStructTypeError(raw) } } @@ -586,8 +584,7 @@ object StructType extends AbstractDataType { nullable = leftNullable || rightNullable) } catch { case NonFatal(e) => - throw new SparkException(s"Failed to merge fields '$leftName' and " + - s"'$rightName'. " + e.getMessage) + throw QueryExecutionErrors.failedMergingFieldsError(leftName, rightName, e) } } .orElse { @@ -610,14 +607,14 @@ object StructType extends AbstractDataType { if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { DecimalType(leftPrecision, leftScale) } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { - throw new SparkException("Failed to merge decimal types with incompatible " + - s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") + throw QueryExecutionErrors.cannotMergeDecimalTypesWithIncompatiblePrecisionAndScaleError( + leftPrecision, rightPrecision, leftScale, rightScale) } else if (leftPrecision != rightPrecision) { - throw new SparkException("Failed to merge decimal types with incompatible " + - s"precision $leftPrecision and $rightPrecision") + throw QueryExecutionErrors.cannotMergeDecimalTypesWithIncompatiblePrecisionError( + leftPrecision, rightPrecision) } else { - throw new SparkException("Failed to merge decimal types with incompatible " + - s"scala $leftScale and $rightScale") + throw QueryExecutionErrors.cannotMergeDecimalTypesWithIncompatibleScaleError( + leftScale, rightScale) } case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) @@ -627,8 +624,7 @@ object StructType extends AbstractDataType { leftType case _ => - throw new SparkException(s"Failed to merge incompatible data types ${left.catalogString}" + - s" and ${right.catalogString}") + throw QueryExecutionErrors.cannotMergeIncompatibleDataTypesError(left, right) } private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala index 7026ff7de2..a3e76797b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala @@ -21,12 +21,13 @@ import scala.math.Numeric._ import scala.math.Ordering import org.apache.spark.sql.catalyst.util.SQLOrderingUtil +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.Decimal.DecimalIsConflicted private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering { private def checkOverflow(res: Int, x: Byte, y: Byte, op: String): Unit = { if (res > Byte.MaxValue || res < Byte.MinValue) { - throw new ArithmeticException(s"$x $op $y caused overflow.") + throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y) } } @@ -50,7 +51,7 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOr override def 
negate(x: Byte): Byte = { if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow can happen - throw new ArithmeticException(s"- $x caused overflow.") + throw QueryExecutionErrors.unaryMinusCauseOverflowError(x) } (-x).toByte } @@ -60,7 +61,7 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOr private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering { private def checkOverflow(res: Int, x: Short, y: Short, op: String): Unit = { if (res > Short.MaxValue || res < Short.MinValue) { - throw new ArithmeticException(s"$x $op $y caused overflow.") + throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y) } } @@ -84,7 +85,7 @@ private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.Shor override def negate(x: Short): Short = { if (x == Short.MinValue) { // if and only if x is Byte.MinValue, overflow can happen - throw new ArithmeticException(s"- $x caused overflow.") + throw QueryExecutionErrors.unaryMinusCauseOverflowError(x) } (-x).toShort } @@ -114,14 +115,11 @@ private[sql] object LongExactNumeric extends LongIsIntegral with Ordering.LongOr if (x == x.toInt) { x.toInt } else { - throw new ArithmeticException(s"Casting $x to int causes overflow") + throw QueryExecutionErrors.castingCauseOverflowError(x, "int") } } private[sql] object FloatExactNumeric extends FloatIsFractional { - private def overflowException(x: Float, dataType: String) = - throw new ArithmeticException(s"Casting $x to $dataType causes overflow") - private val intUpperBound = Int.MaxValue private val intLowerBound = Int.MinValue private val longUpperBound = Long.MaxValue @@ -137,7 +135,7 @@ private[sql] object FloatExactNumeric extends FloatIsFractional { if (Math.floor(x) <= intUpperBound && Math.ceil(x) >= intLowerBound) { x.toInt } else { - overflowException(x, "int") + throw QueryExecutionErrors.castingCauseOverflowError(x, "int") } } @@ -145,7 +143,7 @@ private[sql] object FloatExactNumeric extends FloatIsFractional { if (Math.floor(x) <= longUpperBound && Math.ceil(x) >= longLowerBound) { x.toLong } else { - overflowException(x, "int") + throw QueryExecutionErrors.castingCauseOverflowError(x, "int") } } @@ -153,9 +151,6 @@ private[sql] object FloatExactNumeric extends FloatIsFractional { } private[sql] object DoubleExactNumeric extends DoubleIsFractional { - private def overflowException(x: Double, dataType: String) = - throw new ArithmeticException(s"Casting $x to $dataType causes overflow") - private val intUpperBound = Int.MaxValue private val intLowerBound = Int.MinValue private val longUpperBound = Long.MaxValue @@ -165,7 +160,7 @@ private[sql] object DoubleExactNumeric extends DoubleIsFractional { if (Math.floor(x) <= intUpperBound && Math.ceil(x) >= intLowerBound) { x.toInt } else { - overflowException(x, "int") + throw QueryExecutionErrors.castingCauseOverflowError(x, "int") } } @@ -173,7 +168,7 @@ private[sql] object DoubleExactNumeric extends DoubleIsFractional { if (Math.floor(x) <= longUpperBound && Math.ceil(x) >= longLowerBound) { x.toLong } else { - overflowException(x, "long") + throw QueryExecutionErrors.castingCauseOverflowError(x, "long") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 5d5da795a5..ce8acd1825 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ 
-21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.complex.MapVector -import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.sql.internal.SQLConf @@ -54,6 +54,8 @@ private[sql] object ArrowUtils { new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) } case NullType => ArrowType.Null.INSTANCE + case YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) + case DayTimeIntervalType => new ArrowType.Interval(IntervalUnit.DAY_TIME) case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") } @@ -74,6 +76,8 @@ private[sql] object ArrowUtils { case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType case ArrowType.Null.INSTANCE => NullType + case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType + case di: ArrowType.Interval if di.getUnit == IntervalUnit.DAY_TIME => DayTimeIntervalType case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index 0dbae707a4..169c5d6a31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData, IntervalUtils} +import org.apache.spark.sql.catalyst.util.{DateTimeConstants, DateTimeUtils, GenericArrayData, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -197,8 +197,8 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper { "1970-01-01", "1972-12-31", "2019-02-16", - "2119-03-16").foreach { timestamp => - val input = LocalDate.parse(timestamp) + "2119-03-16").foreach { date => + val input = LocalDate.parse(date) val result = CatalystTypeConverters.convertToCatalyst(input) val expected = DateTimeUtils.localDateToDays(input) assert(result === expected) @@ -294,4 +294,44 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper { } } } + + test("SPARK-35204: createToCatalystConverter for date") { + Seq(true, false).foreach { enable => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> enable.toString) { + Seq(-1234, 0, 1234).foreach { days => + val converter = CatalystTypeConverters.createToCatalystConverter(DateType) + + val ld = LocalDate.ofEpochDay(days) + val result1 = converter(ld) + + val d = java.sql.Date.valueOf(ld) + val result2 = converter(d) + + val expected = DateTimeUtils.localDateToDays(ld) + assert(result1 === expected) + assert(result2 === expected) + } + } + } + } + + test("SPARK-35204: createToCatalystConverter for timestamp") { + Seq(true, false).foreach { enable => + 
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> enable.toString) { + Seq(-1234, 0, 1234).foreach { seconds => + val converter = CatalystTypeConverters.createToCatalystConverter(TimestampType) + + val i = Instant.ofEpochSecond(seconds) + val result1 = converter(i) + + val t = new java.sql.Timestamp(seconds * DateTimeConstants.MILLIS_PER_SECOND) + val result2 = converter(t) + + val expected = seconds * DateTimeConstants.MICROS_PER_SECOND + assert(result1 === expected) + assert(result2 === expected) + } + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 4ac7823502..dc9f92d7c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, Inner} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.connector.InMemoryTable +import org.apache.spark.sql.connector.catalog.InMemoryTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala index 1c849fa21e..f7e57e3b27 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode} -import org.apache.spark.sql.connector.InMemoryTableCatalog -import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, TableCatalog} import org.apache.spark.sql.connector.expressions.Expressions import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala index ec9480514b..7d6ad3bc60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala @@ -29,8 +29,7 @@ import org.scalatest.matchers.must.Matchers import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, ExternalCatalog, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.connector.{InMemoryTable, InMemoryTableCatalog} -import 
org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, InMemoryTable, InMemoryTableCatalog, Table} import org.apache.spark.sql.types._ class TableLookupCacheSuite extends AnalysisTest with Matchers { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 095894b9ff..aec8725d51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import java.time.{Duration, Period} import java.util.TimeZone import scala.language.implicitConversions @@ -28,7 +29,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} -import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC import org.apache.spark.sql.catalyst.util.IntervalUtils._ import org.apache.spark.sql.internal.SQLConf @@ -932,7 +932,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Date.valueOf("1970-02-01")), Literal(negateExact(stringToInterval("interval 1 month")))), EmptyRow, - s"sequence boundaries: 0 to 2678400000000 by -${28 * MICROS_PER_DAY}") + s"sequence boundaries: 0 to 2678400000000 by -1 months") // SPARK-32133: Sequence step must be a day interval if start and end values are dates checkExceptionInExpression[IllegalArgumentException](Sequence( @@ -943,6 +943,178 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } } + test("SPARK-35088: Accept ANSI intervals by the Sequence expression") { + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Duration.ofHours(12))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:01")), + Literal(Duration.ofHours(12))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Duration.ofHours(-12))), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2017-12-31 23:59:59")), + Literal(Duration.ofHours(-12))), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + 
Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(Period.ofMonths(1))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-03-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Period.ofMonths(-1))), + Seq( + Timestamp.valueOf("2018-03-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-31 00:00:00")), + Literal(Timestamp.valueOf("2018-04-30 00:00:00")), + Literal(Period.ofMonths(1))), + Seq( + Timestamp.valueOf("2018-01-31 00:00:00"), + Timestamp.valueOf("2018-02-28 00:00:00"), + Timestamp.valueOf("2018-03-31 00:00:00"), + Timestamp.valueOf("2018-04-30 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2023-01-01 00:00:00")), + Literal(Period.of(1, 5, 0))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2022-04-01 00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2022-04-01 00:00:00")), + Literal(Timestamp.valueOf("2017-01-01 00:00:00")), + Literal(Period.of(-1, -5, 0))), + Seq( + Timestamp.valueOf("2022-04-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2018-01-01 00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-04 00:00:00")), + Literal(Duration.ofDays(1))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00.000"), + Timestamp.valueOf("2018-01-02 00:00:00.000"), + Timestamp.valueOf("2018-01-03 00:00:00.000"), + Timestamp.valueOf("2018-01-04 00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-04 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Duration.ofDays(-1))), + Seq( + Timestamp.valueOf("2018-01-04 00:00:00.000"), + Timestamp.valueOf("2018-01-03 00:00:00.000"), + Timestamp.valueOf("2018-01-02 00:00:00.000"), + Timestamp.valueOf("2018-01-01 00:00:00.000"))) + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-04 00:00:00")), + Literal(Period.ofDays(1))), + EmptyRow, s"sequence boundaries: 1514793600000000 to 1515052800000000 by 0") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Timestamp.valueOf("2018-01-04 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Period.ofDays(-1))), + EmptyRow, s"sequence boundaries: 1515052800000000 to 1514793600000000 by 0") + + DateTimeTestUtils.withDefaultTimeZone(UTC) { + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-03-01")), + Literal(Period.ofMonths(1))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-31")), + Literal(Date.valueOf("2018-04-30")), + Literal(Period.ofMonths(1))), + Seq( + Date.valueOf("2018-01-31"), + Date.valueOf("2018-02-28"), + Date.valueOf("2018-03-31"), + Date.valueOf("2018-04-30"))) + + checkEvaluation(new 
Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2023-01-01")), + Literal(Period.of(1, 5, 0))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2019-06-01"), + Date.valueOf("2020-11-01"), + Date.valueOf("2022-04-01"))) + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-05")), + Literal(Period.ofDays(2))), + EmptyRow, + "sequence step must be a day year-month interval if start and end values are dates") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-01")), + Literal(Date.valueOf("1970-02-01")), + Literal(Period.ofMonths(-1))), + EmptyRow, + s"sequence boundaries: 0 to 2678400000000 by -1") + + assert(Sequence( + Cast(Literal("2011-03-01"), DateType), + Cast(Literal("2011-04-01"), DateType), + Option(Literal(Duration.ofHours(1)))).checkInputDataTypes().isFailure) + } + } + test("Sequence with default step") { // +/- 1 for integral type checkEvaluation(new Sequence(Literal(1), Literal(3)), Seq(1, 2, 3)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 3f3a64ef8d..cf2f5057cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -23,8 +23,8 @@ import java.time.temporal.ChronoUnit import scala.language.implicitConversions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ -import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.catalyst.util.IntervalUtils.{safeStringToInterval, stringToInterval} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DayTimeIntervalType, Decimal, DecimalType, YearMonthIntervalType} @@ -76,17 +76,17 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("hours") { - checkEvaluation(ExtractIntervalHours("0 hours"), 0L) - checkEvaluation(ExtractIntervalHours("1 hour"), 1L) - checkEvaluation(ExtractIntervalHours("-1 hour"), -1L) - checkEvaluation(ExtractIntervalHours("23 hours"), 23L) - checkEvaluation(ExtractIntervalHours("-23 hours"), -23L) + checkEvaluation(ExtractIntervalHours("0 hours"), 0.toByte) + checkEvaluation(ExtractIntervalHours("1 hour"), 1.toByte) + checkEvaluation(ExtractIntervalHours("-1 hour"), -1.toByte) + checkEvaluation(ExtractIntervalHours("23 hours"), 23.toByte) + checkEvaluation(ExtractIntervalHours("-23 hours"), -23.toByte) // Years, months and days must not be taken into account - checkEvaluation(ExtractIntervalHours("100 year 10 months 10 days 10 hours"), 10L) + checkEvaluation(ExtractIntervalHours("100 year 10 months 10 days 10 hours"), 10.toByte) // Minutes should be taken into account - checkEvaluation(ExtractIntervalHours("10 hours 100 minutes"), 11L) - checkEvaluation(ExtractIntervalHours(largeInterval), 11L) - checkEvaluation(ExtractIntervalHours("25 hours"), 1L) + checkEvaluation(ExtractIntervalHours("10 hours 100 minutes"), 11.toByte) + checkEvaluation(ExtractIntervalHours(largeInterval), 11.toByte) + checkEvaluation(ExtractIntervalHours("25 hours"), 1.toByte) } @@ -410,4 +410,40 @@ class 
IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DayTimeIntervalType, numType) } } + + test("ANSI: extract years and months") { + Seq(Period.ZERO, + Period.ofMonths(100), + Period.ofMonths(-100), + Period.ofYears(100), + Period.ofYears(-100)).foreach { p => + checkEvaluation(ExtractANSIIntervalYears(Literal(p)), + IntervalUtils.getYears(p.toTotalMonths.toInt)) + checkEvaluation(ExtractANSIIntervalMonths(Literal(p)), + IntervalUtils.getMonths(p.toTotalMonths.toInt)) + } + checkEvaluation(ExtractANSIIntervalYears(Literal(null, YearMonthIntervalType)), null) + checkEvaluation(ExtractANSIIntervalMonths(Literal(null, YearMonthIntervalType)), null) + } + + test("ANSI: extract days, hours, minutes and seconds") { + Seq(Duration.ZERO, + Duration.ofMillis(1L * MILLIS_PER_DAY + 2 * MILLIS_PER_SECOND), + Duration.ofMillis(-1L * MILLIS_PER_DAY + 2 * MILLIS_PER_SECOND), + Duration.ofDays(100), + Duration.ofDays(-100), + Duration.ofHours(-100)).foreach { d => + + checkEvaluation(ExtractANSIIntervalDays(Literal(d)), d.toDays.toInt) + checkEvaluation(ExtractANSIIntervalHours(Literal(d)), (d.toHours % HOURS_PER_DAY).toByte) + checkEvaluation(ExtractANSIIntervalMinutes(Literal(d)), + (d.toMinutes % MINUTES_PER_HOUR).toByte) + checkEvaluation(ExtractANSIIntervalSeconds(Literal(d)), + IntervalUtils.getSeconds(IntervalUtils.durationToMicros(d))) + } + checkEvaluation(ExtractANSIIntervalDays(Literal(null, DayTimeIntervalType)), null) + checkEvaluation(ExtractANSIIntervalHours(Literal(null, DayTimeIntervalType)), null) + checkEvaluation(ExtractANSIIntervalMinutes(Literal(null, DayTimeIntervalType)), null) + checkEvaluation(ExtractANSIIntervalSeconds(Literal(null, DayTimeIntervalType)), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 926628aca9..3d11ff97f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.UTF8String class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -50,8 +51,10 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { testBothCodegenAndInterpreted("unsafe buffer") { val inputRow = InternalRow.fromSeq(Seq( false, 1.toByte, 9.toShort, -18, 53L, 3.2f, 7.8, 4, 9L, Int.MinValue, Long.MaxValue)) - val numBytes = UnsafeRow.calculateBitSetWidthInBytes(fixedLengthTypes.length) - val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, fixedLengthTypes.length) + val numFields = fixedLengthTypes.length + val numBytes = Platform.BYTE_ARRAY_OFFSET + UnsafeRow.calculateBitSetWidthInBytes(numFields) + + UnsafeRow.WORD_SIZE * numFields + val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, numFields) val proj = createMutableProjection(fixedLengthTypes) val projUnsafeRow = proj.target(unsafeBuffer)(inputRow) assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala index 7895f4d5ef..2fab553183 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala @@ -18,39 +18,36 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE import org.apache.spark.sql.types._ class SchemaPruningSuite extends SparkFunSuite with SQLHelper { - - def getRootFields(requestedFields: StructField*): Seq[RootField] = { - requestedFields.map { f => + private def testPrunedSchema( + schema: StructType, + requestedFields: Seq[StructField], + expectedSchema: StructType): Unit = { + val requestedRootFields = requestedFields.map { f => // `derivedFromAtt` doesn't affect the result of pruned schema. SchemaPruning.RootField(field = f, derivedFromAtt = true) } + val prunedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields) + assert(prunedSchema === expectedSchema) } test("prune schema by the requested fields") { - def testPrunedSchema( - schema: StructType, - requestedFields: StructField*): Unit = { - val requestedRootFields = requestedFields.map { f => - // `derivedFromAtt` doesn't affect the result of pruned schema. - SchemaPruning.RootField(field = f, derivedFromAtt = true) - } - val expectedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields) - assert(expectedSchema == StructType(requestedFields)) - } - - testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("a", IntegerType)) - testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("b", IntegerType)) + testPrunedSchema( + StructType.fromDDL("a int, b int"), + Seq(StructField("a", IntegerType)), + StructType.fromDDL("a int, b int")) val structOfStruct = StructType.fromDDL("a struct, b int") - testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("a int, b int"))) - testPrunedSchema(structOfStruct, StructField("b", IntegerType)) - testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("b int"))) + testPrunedSchema(structOfStruct, + Seq(StructField("a", StructType.fromDDL("a int")), StructField("b", IntegerType)), + StructType.fromDDL("a struct, b int")) + testPrunedSchema(structOfStruct, + Seq(StructField("a", StructType.fromDDL("a int"))), + StructType.fromDDL("a struct, b int")) val arrayOfStruct = StructField("a", ArrayType(StructType.fromDDL("a int, b int, c string"))) val mapOfStruct = StructField("d", MapType(StructType.fromDDL("a int, b int, c string"), @@ -60,44 +57,76 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper { arrayOfStruct :: StructField("b", structOfStruct) :: StructField("c", IntegerType) :: mapOfStruct :: Nil) - testPrunedSchema(complexStruct, StructField("a", ArrayType(StructType.fromDDL("b int"))), - StructField("b", StructType.fromDDL("a int"))) testPrunedSchema(complexStruct, - StructField("a", ArrayType(StructType.fromDDL("b int, c string"))), - StructField("b", StructType.fromDDL("b int"))) + Seq(StructField("a", ArrayType(StructType.fromDDL("b int"))), + StructField("b", StructType.fromDDL("a int"))), + StructType( + StructField("a", 
ArrayType(StructType.fromDDL("b int"))) :: + StructField("b", StructType.fromDDL("a int")) :: + StructField("c", IntegerType) :: + mapOfStruct :: Nil)) + testPrunedSchema(complexStruct, + Seq(StructField("a", ArrayType(StructType.fromDDL("b int, c string"))), + StructField("b", StructType.fromDDL("b int"))), + StructType( + StructField("a", ArrayType(StructType.fromDDL("b int, c string"))) :: + StructField("b", StructType.fromDDL("b int")) :: + StructField("c", IntegerType) :: + mapOfStruct :: Nil)) val selectFieldInMap = StructField("d", MapType(StructType.fromDDL("a int, b int"), StructType.fromDDL("e int, f string"))) - testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap) + testPrunedSchema(complexStruct, + Seq(StructField("c", IntegerType), selectFieldInMap), + StructType( + arrayOfStruct :: + StructField("b", structOfStruct) :: + StructField("c", IntegerType) :: + selectFieldInMap :: Nil)) } test("SPARK-35096: test case insensitivity of pruned schema") { - Seq(true, false).foreach(isCaseSensitive => { + val upperCaseSchema = StructType.fromDDL("A struct, B int") + val lowerCaseSchema = StructType.fromDDL("a struct, b int") + val upperCaseRequestedFields = Seq(StructField("A", StructType.fromDDL("A int"))) + val lowerCaseRequestedFields = Seq(StructField("a", StructType.fromDDL("a int"))) + + Seq(true, false).foreach { isCaseSensitive => withSQLConf(CASE_SENSITIVE.key -> isCaseSensitive.toString) { if (isCaseSensitive) { - // Schema is case-sensitive - val requestedFields = getRootFields(StructField("id", IntegerType)) - val prunedSchema = SchemaPruning.pruneDataSchema( - StructType.fromDDL("ID int, name String"), requestedFields) - assert(prunedSchema == StructType(Seq.empty)) - // Root fields are case-sensitive - val rootFieldsSchema = SchemaPruning.pruneDataSchema( - StructType.fromDDL("id int, name String"), - getRootFields(StructField("ID", IntegerType))) - assert(rootFieldsSchema == StructType(StructType(Seq.empty))) + testPrunedSchema( + upperCaseSchema, + upperCaseRequestedFields, + StructType.fromDDL("A struct, B int")) + testPrunedSchema( + upperCaseSchema, + lowerCaseRequestedFields, + upperCaseSchema) + + testPrunedSchema( + lowerCaseSchema, + upperCaseRequestedFields, + lowerCaseSchema) + testPrunedSchema( + lowerCaseSchema, + lowerCaseRequestedFields, + StructType.fromDDL("a struct, b int")) } else { - // Schema is case-insensitive - val prunedSchema = SchemaPruning.pruneDataSchema( - StructType.fromDDL("ID int, name String"), - getRootFields(StructField("id", IntegerType))) - assert(prunedSchema == StructType(StructField("ID", IntegerType) :: Nil)) - // Root fields are case-insensitive - val rootFieldsSchema = SchemaPruning.pruneDataSchema( - StructType.fromDDL("id int, name String"), - getRootFields(StructField("ID", IntegerType))) - assert(rootFieldsSchema == StructType(StructField("id", IntegerType) :: Nil)) + Seq(upperCaseRequestedFields, lowerCaseRequestedFields).foreach { requestedFields => + testPrunedSchema( + upperCaseSchema, + requestedFields, + StructType.fromDDL("A struct, B int")) + } + + Seq(upperCaseRequestedFields, lowerCaseRequestedFields).foreach { requestedFields => + testPrunedSchema( + lowerCaseSchema, + requestedFields, + StructType.fromDDL("a struct, b int")) + } } } - }) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala index 441c15340a..997ccb7204 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala @@ -31,8 +31,10 @@ class CombineConcatsSuite extends PlanTest { } protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { - val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze - val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze) + val correctAnswer = Limit(Literal(1), Project(Alias(e2, "out")() :: Nil, OneRowRelation())) + .analyze + val actual = Optimize.execute(Limit(Literal(1), Project(Alias(e1, "out")() :: Nil, + OneRowRelation())).analyze) comparePlans(actual, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala index 0ae4d3f6e6..a856caa678 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala @@ -329,14 +329,14 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { comparePlans(optimized, expected) } - test("Nested field pruning for Project and Generate: not prune on generator output") { + test("Nested field pruning for Project and Generate: multiple-field case is not supported") { val companies = LocalRelation( 'id.int, 'employers.array(employer)) val query = companies .generate(Explode('employers.getField("company")), outputNames = Seq("company")) - .select('company.getField("name")) + .select('company.getField("name"), 'company.getField("address")) .analyze val optimized = Optimize.execute(query) @@ -347,7 +347,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { .generate(Explode($"${aliases(0)}"), unrequiredChildIndex = Seq(0), outputNames = Seq("company")) - .select('company.getField("name").as("company.name")) + .select('company.getField("name").as("company.name"), + 'company.getField("address").as("company.address")) .analyze comparePlans(optimized, expected) } @@ -684,6 +685,29 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { ).analyze comparePlans(optimized2, expected2) } + + test("SPARK-34638: nested column prune on generator output for one field") { + val companies = LocalRelation( + 'id.int, + 'employers.array(employer)) + + val query = companies + .generate(Explode('employers.getField("company")), outputNames = Seq("company")) + .select('company.getField("name")) + .analyze + val optimized = Optimize.execute(query) + + val aliases = collectGeneratedAliases(optimized) + + val expected = companies + .select('employers.getField("company").getField("name").as(aliases(0))) + .generate(Explode($"${aliases(0)}"), + unrequiredChildIndex = Seq(0), + outputNames = Seq("company")) + .select('company.as("company.name")) + .analyze + comparePlans(optimized, expected) + } } object NestedColumnAliasingSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala index b093b39cc4..e63742ac0d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala @@ -126,4 +126,25 @@ class OptimizeWithFieldsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } } + + test("SPARK-35213: ensure optimize WithFields maintains correct WithField ordering") { + val originalQuery = testRelation + .select( + Alias(UpdateFields('a, + WithField("a1", Literal(3)) :: + WithField("b1", Literal(4)) :: + WithField("a1", Literal(5)) :: + Nil), "out")()) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select( + Alias(UpdateFields('a, + WithField("a1", Literal(5)) :: + WithField("b1", Literal(4)) :: + Nil), "out")()) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala index 3eba003d77..d376c31ef9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala @@ -96,7 +96,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy('a + 'b)(('a + 'b) as 'c) .analyze val optimized = Optimize.execute(query) - comparePlans(optimized, EnforceGroupingReferencesInAggregates(expected)) + comparePlans(optimized, expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index d149967094..dcd2fbbf00 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -36,8 +36,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { object Optimizer extends RuleExecutor[LogicalPlan] { val batches = - Batch("Finish Analysis", Once, - EnforceGroupingReferencesInAggregates) :: Batch("collapse projections", FixedPoint(10), CollapseProject) :: Batch("Constant Folding", FixedPoint(10), @@ -59,7 +57,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = { val optimized = Optimizer.execute(originalQuery.analyze) assert(optimized.resolved, "optimized plans must be still resolvable") - comparePlans(optimized, EnforceGroupingReferencesInAggregates(correctAnswer.analyze)) + comparePlans(optimized, correctAnswer.analyze) } test("explicit get from namedStruct") { @@ -407,6 +405,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val arrayAggRel = relation.groupBy( CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) checkRule(arrayAggRel, arrayAggRel) + + // This could be done if we had a more complex rule that checks that + // the CreateMap does not come from key. 
+ val originalQuery = relation + .groupBy('id)( + GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a" + ) + checkRule(originalQuery, originalQuery) } test("SPARK-23500: namedStruct and getField in the same Project #1") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala index 5c460f70a9..87d306a495 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -169,6 +169,19 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { fromYearMonthString) failFuncWithInvalidInput("-\t99-15", "Interval string does not match year-month format", fromYearMonthString) + + assert(fromYearMonthString("178956970-6") == new CalendarInterval(Int.MaxValue - 1, 0, 0)) + assert(fromYearMonthString("178956970-7") == new CalendarInterval(Int.MaxValue, 0, 0)) + + val e1 = intercept[IllegalArgumentException]{ + assert(fromYearMonthString("178956970-8") == new CalendarInterval(Int.MinValue, 0, 0)) + }.getMessage + assert(e1.contains("integer overflow")) + assert(fromYearMonthString("-178956970-8") == new CalendarInterval(Int.MinValue, 0, 0)) + val e2 = intercept[IllegalArgumentException]{ + assert(fromYearMonthString("-178956970-9") == new CalendarInterval(Int.MinValue, 0, 0)) + }.getMessage + assert(e2.contains("integer overflow")) } test("from day-time string - legacy") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala index aec361b979..eb35dd47a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogManagerSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, FakeV2SessionCatalog, NoSuchNamespaceException} import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.connector.InMemoryTableCatalog import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.CaseInsensitiveStringMap diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala new file mode 100644 index 0000000000..a48eb04a98 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryAtomicPartitionTable.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException, PartitionsAlreadyExistException} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.types.StructType + +/** + * This class is used to test SupportsAtomicPartitionManagement API. + */ +class InMemoryAtomicPartitionTable ( + name: String, + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]) + extends InMemoryPartitionTable(name, schema, partitioning, properties) + with SupportsAtomicPartitionManagement { + + override def createPartition( + ident: InternalRow, + properties: util.Map[String, String]): Unit = { + if (memoryTablePartitions.containsKey(ident)) { + throw new PartitionAlreadyExistsException(name, ident, partitionSchema) + } else { + createPartitionKey(ident.toSeq(schema)) + memoryTablePartitions.put(ident, properties) + } + } + + override def dropPartition(ident: InternalRow): Boolean = { + if (memoryTablePartitions.containsKey(ident)) { + memoryTablePartitions.remove(ident) + removePartitionKey(ident.toSeq(schema)) + true + } else { + false + } + } + + override def createPartitions( + idents: Array[InternalRow], + properties: Array[util.Map[String, String]]): Unit = { + if (idents.exists(partitionExists)) { + throw new PartitionsAlreadyExistException( + name, idents.filter(partitionExists), partitionSchema) + } + idents.zip(properties).foreach { case (ident, property) => + createPartition(ident, property) + } + } + + override def dropPartitions(idents: Array[InternalRow]): Boolean = { + if (!idents.forall(partitionExists)) { + return false; + } + idents.forall(dropPartition) + } + + override def truncatePartitions(idents: Array[InternalRow]): Boolean = { + val nonExistent = idents.filterNot(partitionExists) + if (nonExistent.isEmpty) { + idents.foreach(truncatePartition) + true + } else { + throw new NoSuchPartitionException(name, nonExistent.head, partitionSchema) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala new file mode 100644 index 0000000000..58dc484711 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTable.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.types.StructType + +/** + * This class is used to test SupportsPartitionManagement API. + */ +class InMemoryPartitionTable( + name: String, + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]) + extends InMemoryTable(name, schema, partitioning, properties) with SupportsPartitionManagement { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + protected val memoryTablePartitions: util.Map[InternalRow, util.Map[String, String]] = + new ConcurrentHashMap[InternalRow, util.Map[String, String]]() + + def partitionSchema: StructType = { + val partitionColumnNames = partitioning.toSeq.asPartitionColumns + new StructType(schema.filter(p => partitionColumnNames.contains(p.name)).toArray) + } + + def createPartition( + ident: InternalRow, + properties: util.Map[String, String]): Unit = { + if (memoryTablePartitions.containsKey(ident)) { + throw new PartitionAlreadyExistsException(name, ident, partitionSchema) + } else { + createPartitionKey(ident.toSeq(schema)) + memoryTablePartitions.put(ident, properties) + } + } + + def dropPartition(ident: InternalRow): Boolean = { + if (memoryTablePartitions.containsKey(ident)) { + memoryTablePartitions.remove(ident) + removePartitionKey(ident.toSeq(schema)) + true + } else { + false + } + } + + def replacePartitionMetadata(ident: InternalRow, properties: util.Map[String, String]): Unit = { + if (memoryTablePartitions.containsKey(ident)) { + memoryTablePartitions.put(ident, properties) + } else { + throw new NoSuchPartitionException(name, ident, partitionSchema) + } + } + + def loadPartitionMetadata(ident: InternalRow): util.Map[String, String] = { + if (memoryTablePartitions.containsKey(ident)) { + memoryTablePartitions.get(ident) + } else { + throw new NoSuchPartitionException(name, ident, partitionSchema) + } + } + + override protected def addPartitionKey(key: Seq[Any]): Unit = { + memoryTablePartitions.putIfAbsent(InternalRow.fromSeq(key), Map.empty[String, String].asJava) + } + + override def listPartitionIdentifiers( + names: Array[String], + ident: InternalRow): Array[InternalRow] = { + assert(names.length == ident.numFields, + s"Number of partition names (${names.length}) must be equal to " + + s"the number of partition values (${ident.numFields}).") + val schema = partitionSchema + assert(names.forall(fieldName => schema.fieldNames.contains(fieldName)), + s"Some partition names ${names.mkString("[", ", ", "]")} don't belong to " + + s"the partition schema '${schema.sql}'.") + val indexes = names.map(schema.fieldIndex) + val dataTypes = names.map(schema(_).dataType) + val currentRow = new 
GenericInternalRow(new Array[Any](names.length)) + memoryTablePartitions.keySet().asScala.filter { key => + for (i <- 0 until names.length) { + currentRow.values(i) = key.get(indexes(i), dataTypes(i)) + } + currentRow == ident + }.toArray + } + + override def renamePartition(from: InternalRow, to: InternalRow): Boolean = { + if (memoryTablePartitions.containsKey(to)) { + throw new PartitionAlreadyExistsException(name, to, partitionSchema) + } else { + val partValue = memoryTablePartitions.remove(from) + if (partValue == null) { + throw new NoSuchPartitionException(name, from, partitionSchema) + } + memoryTablePartitions.put(to, partValue) == null && + renamePartitionKey(partitionSchema, from.toSeq(schema), to.toSeq(schema)) + } + } + + override def truncatePartition(ident: InternalRow): Boolean = { + if (memoryTablePartitions.containsKey(ident)) { + clearPartition(ident.toSeq(schema)) + true + } else { + throw new NoSuchPartitionException(name, ident, partitionSchema) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala new file mode 100644 index 0000000000..a24f5c9a0c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryPartitionTableCatalog.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog + +import java.util + +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.types.StructType + +class InMemoryPartitionTableCatalog extends InMemoryTableCatalog { + import CatalogV2Implicits._ + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident) + } + + InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) + + val table = new InMemoryAtomicPartitionTable( + s"$name.${ident.quoted}", schema, partitions, properties) + tables.put(ident, table) + namespaces.putIfAbsent(ident.namespace.toList, Map()) + table + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala new file mode 100644 index 0000000000..b9069ff311 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -0,0 +1,535 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.time.{Instant, ZoneId} +import java.time.temporal.ChronoUnit +import java.util + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.scalatest.Assertions._ + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow} +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils} +import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.expressions._ +import org.apache.spark.sql.connector.read._ +import org.apache.spark.sql.connector.write._ +import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String + +/** + * A simple in-memory table. Rows are stored as a buffered group produced by each output task. 
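+ * An illustrative use in a test (a sketch with made-up names, assuming a single `id: long`
+ * column and no partitioning) might look like:
+ * {{{
+ *   import scala.collection.JavaConverters._
+ *   val table = new InMemoryTable(
+ *     "testcat.t", new StructType().add("id", LongType), Array.empty[Transform],
+ *     Map.empty[String, String].asJava)
+ *   // Rows are buffered per partition key; with no partitioning they share one empty key.
+ *   table.withData(Array(new BufferedRows().withRow(InternalRow(1L))))
+ *   assert(table.rows.map(_.getLong(0)) == Seq(1L))
+ * }}}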
+ */ +class InMemoryTable( + val name: String, + val schema: StructType, + override val partitioning: Array[Transform], + override val properties: util.Map[String, String], + val distribution: Distribution = Distributions.unspecified(), + val ordering: Array[SortOrder] = Array.empty, + val numPartitions: Option[Int] = None) + extends Table with SupportsRead with SupportsWrite with SupportsDelete + with SupportsMetadataColumns { + + private object PartitionKeyColumn extends MetadataColumn { + override def name: String = "_partition" + override def dataType: DataType = StringType + override def comment: String = "Partition key used to store the row" + } + + private object IndexColumn extends MetadataColumn { + override def name: String = "index" + override def dataType: DataType = IntegerType + override def comment: String = "Metadata column used to conflict with a data column" + } + + // purposely exposes a metadata column that conflicts with a data column in some tests + override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn) + private val metadataColumnNames = metadataColumns.map(_.name).toSet -- schema.map(_.name) + + private val allowUnsupportedTransforms = + properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean + + partitioning.foreach { + case _: IdentityTransform => + case _: YearsTransform => + case _: MonthsTransform => + case _: DaysTransform => + case _: HoursTransform => + case _: BucketTransform => + case t if !allowUnsupportedTransforms => + throw new IllegalArgumentException(s"Transform $t is not a supported transform") + } + + // The key `Seq[Any]` is the partition values. + val dataMap: mutable.Map[Seq[Any], BufferedRows] = mutable.Map.empty + + def data: Array[BufferedRows] = dataMap.values.toArray + + def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq + + private val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref => + schema.findNestedField(ref.fieldNames(), includeCollections = false) match { + case Some(_) => ref.fieldNames() + case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.") + } + } + + private val UTC = ZoneId.of("UTC") + private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate + + private def getKey(row: InternalRow): Seq[Any] = { + def extractor( + fieldNames: Array[String], + schema: StructType, + row: InternalRow): (Any, DataType) = { + val index = schema.fieldIndex(fieldNames(0)) + val value = row.toSeq(schema).apply(index) + if (fieldNames.length > 1) { + (value, schema(index).dataType) match { + case (row: InternalRow, nestedSchema: StructType) => + extractor(fieldNames.drop(1), nestedSchema, row) + case (_, dataType) => + throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}") + } + } else { + (value, schema(index).dataType) + } + } + + val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) + partitioning.map { + case IdentityTransform(ref) => + extractor(ref.fieldNames, cleanedSchema, row)._1 + case YearsTransform(ref) => + extractor(ref.fieldNames, cleanedSchema, row) match { + case (days: Int, DateType) => + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) + case (micros: Long, TimestampType) => + val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) + case (v, t) => + throw new IllegalArgumentException(s"Match: unsupported argument(s) type - 
($v, $t)") + } + case MonthsTransform(ref) => + extractor(ref.fieldNames, cleanedSchema, row) match { + case (days: Int, DateType) => + ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) + case (micros: Long, TimestampType) => + val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate + ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate) + case (v, t) => + throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") + } + case DaysTransform(ref) => + extractor(ref.fieldNames, cleanedSchema, row) match { + case (days, DateType) => + days + case (micros: Long, TimestampType) => + ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) + case (v, t) => + throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") + } + case HoursTransform(ref) => + extractor(ref.fieldNames, cleanedSchema, row) match { + case (micros: Long, TimestampType) => + ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) + case (v, t) => + throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") + } + case BucketTransform(numBuckets, ref) => + val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) + val valueHashCode = if (value == null) 0 else value.hashCode + ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets + } + } + + protected def addPartitionKey(key: Seq[Any]): Unit = {} + + protected def renamePartitionKey( + partitionSchema: StructType, + from: Seq[Any], + to: Seq[Any]): Boolean = { + val rows = dataMap.remove(from).getOrElse(new BufferedRows(from.mkString("/"))) + val newRows = new BufferedRows(to.mkString("/")) + rows.rows.foreach { r => + val newRow = new GenericInternalRow(r.numFields) + for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType)) + for (i <- 0 until partitionSchema.length) { + val j = schema.fieldIndex(partitionSchema(i).name) + newRow.update(j, to(i)) + } + newRows.withRow(newRow) + } + dataMap.put(to, newRows).foreach { _ => + throw new IllegalStateException( + s"The ${to.mkString("[", ", ", "]")} partition exists already") + } + true + } + + protected def removePartitionKey(key: Seq[Any]): Unit = dataMap.synchronized { + dataMap.remove(key) + } + + protected def createPartitionKey(key: Seq[Any]): Unit = dataMap.synchronized { + if (!dataMap.contains(key)) { + val emptyRows = new BufferedRows(key.toArray.mkString("/")) + val rows = if (key.length == schema.length) { + emptyRows.withRow(InternalRow.fromSeq(key)) + } else emptyRows + dataMap.put(key, rows) + } + } + + protected def clearPartition(key: Seq[Any]): Unit = dataMap.synchronized { + assert(dataMap.contains(key)) + dataMap(key).clear() + } + + def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { + data.foreach(_.rows.foreach { row => + val key = getKey(row) + dataMap += dataMap.get(key) + .map(key -> _.withRow(row)) + .getOrElse(key -> new BufferedRows(key.toArray.mkString("/")).withRow(row)) + addPartitionKey(key) + }) + this + } + + override def capabilities: util.Set[TableCapability] = Set( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.STREAMING_WRITE, + TableCapability.OVERWRITE_BY_FILTER, + TableCapability.OVERWRITE_DYNAMIC, + TableCapability.TRUNCATE).asJava + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new InMemoryScanBuilder(schema) + } + + class InMemoryScanBuilder(tableSchema: 
StructType) extends ScanBuilder + with SupportsPushDownRequiredColumns { + private var schema: StructType = tableSchema + + override def build: Scan = + new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema) + + override def pruneColumns(requiredSchema: StructType): Unit = { + // if metadata columns are projected, return the table schema and metadata columns + val hasMetadataColumns = requiredSchema.map(_.name).exists(metadataColumnNames.contains) + if (hasMetadataColumns) { + schema = StructType(tableSchema ++ metadataColumnNames + .flatMap(name => metadataColumns.find(_.name == name)) + .map(col => StructField(col.name, col.dataType, col.isNullable))) + } + } + } + + class InMemoryBatchScan(data: Array[InputPartition], schema: StructType) extends Scan with Batch { + override def readSchema(): StructType = schema + + override def toBatch: Batch = this + + override def planInputPartitions(): Array[InputPartition] = data + + override def createReaderFactory(): PartitionReaderFactory = { + val metadataColumns = schema.map(_.name).filter(metadataColumnNames.contains) + new BufferedRowsReaderFactory(metadataColumns) + } + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + InMemoryTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties)) + InMemoryTable.maybeSimulateFailedTableWrite(info.options) + + new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite { + private var writer: BatchWrite = Append + private var streamingWriter: StreamingWrite = StreamingAppend + + override def truncate(): WriteBuilder = { + assert(writer == Append) + writer = TruncateAndAppend + streamingWriter = StreamingTruncateAndAppend + this + } + + override def overwrite(filters: Array[Filter]): WriteBuilder = { + assert(writer == Append) + writer = new Overwrite(filters) + streamingWriter = new StreamingNotSupportedOperation(s"overwrite ($filters)") + this + } + + override def overwriteDynamicPartitions(): WriteBuilder = { + assert(writer == Append) + writer = DynamicOverwrite + streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions") + this + } + + override def build(): Write = new Write with RequiresDistributionAndOrdering { + override def requiredDistribution: Distribution = distribution + + override def requiredOrdering: Array[SortOrder] = ordering + + override def requiredNumPartitions(): Int = { + numPartitions.getOrElse(0) + } + + override def toBatch: BatchWrite = writer + + override def toStreaming: StreamingWrite = streamingWriter match { + case exc: StreamingNotSupportedOperation => exc.throwsException() + case s => s + } + } + } + } + + private abstract class TestBatchWrite extends BatchWrite { + override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { + BufferedRowsWriterFactory + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + } + + private object Append extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } + + private object DynamicOverwrite extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val newData = messages.map(_.asInstanceOf[BufferedRows]) + dataMap --= newData.flatMap(_.rows.map(getKey)) + withData(newData) + } + } + + private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { + import 
org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val deleteKeys = InMemoryTable.filtersToKeys( + dataMap.keys, partCols.map(_.toSeq.quoted), filters) + dataMap --= deleteKeys + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } + + private object TruncateAndAppend extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + dataMap.clear + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } + + private abstract class TestStreamingWrite extends StreamingWrite { + def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = { + BufferedRowsWriterFactory + } + + def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + } + + private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite { + override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = + throwsException() + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = + throwsException() + + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = + throwsException() + + def throwsException[T](): T = throw new IllegalStateException("The operation " + + s"${operation} isn't supported for streaming query.") + } + + private object StreamingAppend extends TestStreamingWrite { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + dataMap.synchronized { + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } + } + + private object StreamingTruncateAndAppend extends TestStreamingWrite { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + dataMap.synchronized { + dataMap.clear + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } + } + + override def canDeleteWhere(filters: Array[Filter]): Boolean = { + InMemoryTable.supportsFilters(filters) + } + + override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters) + } +} + +object InMemoryTable { + val SIMULATE_FAILED_WRITE_OPTION = "spark.sql.test.simulateFailedWrite" + + def filtersToKeys( + keys: Iterable[Seq[Any]], + partitionNames: Seq[String], + filters: Array[Filter]): Iterable[Seq[Any]] = { + keys.filter { partValues => + filters.flatMap(splitAnd).forall { + case EqualTo(attr, value) => + value == extractValue(attr, partitionNames, partValues) + case EqualNullSafe(attr, value) => + val attrVal = extractValue(attr, partitionNames, partValues) + if (attrVal == null && value === null) { + true + } else if (attrVal == null || value === null) { + false + } else { + value == attrVal + } + case IsNull(attr) => + null == extractValue(attr, partitionNames, partValues) + case IsNotNull(attr) => + null != extractValue(attr, partitionNames, partValues) + case AlwaysTrue() => true + case f => + throw new IllegalArgumentException(s"Unsupported filter type: $f") + } + } + } + + def supportsFilters(filters: Array[Filter]): Boolean = { + filters.flatMap(splitAnd).forall { + case _: EqualTo => true + case _: EqualNullSafe => true + case _: IsNull => true + case _: IsNotNull => true + case _: AlwaysTrue => true + case _ => false + } + } + + private def extractValue( 
+ attr: String, + partFieldNames: Seq[String], + partValues: Seq[Any]): Any = { + partFieldNames.zipWithIndex.find(_._1 == attr) match { + case Some((_, partIndex)) => + partValues(partIndex) + case _ => + throw new IllegalArgumentException(s"Unknown filter attribute: $attr") + } + } + + private def splitAnd(filter: Filter): Seq[Filter] = { + filter match { + case And(left, right) => splitAnd(left) ++ splitAnd(right) + case _ => filter :: Nil + } + } + + def maybeSimulateFailedTableWrite(tableOptions: CaseInsensitiveStringMap): Unit = { + if (tableOptions.getBoolean(SIMULATE_FAILED_WRITE_OPTION, false)) { + throw new IllegalStateException("Manual write to table failure.") + } + } +} + +class BufferedRows( + val key: String = "") extends WriterCommitMessage with InputPartition with Serializable { + val rows = new mutable.ArrayBuffer[InternalRow]() + + def withRow(row: InternalRow): BufferedRows = { + rows.append(row) + this + } + + def clear(): Unit = rows.clear() +} + +private class BufferedRowsReaderFactory( + metadataColumns: Seq[String]) extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + new BufferedRowsReader(partition.asInstanceOf[BufferedRows], metadataColumns) + } +} + +private class BufferedRowsReader( + partition: BufferedRows, + metadataColumns: Seq[String]) extends PartitionReader[InternalRow] { + private def addMetadata(row: InternalRow): InternalRow = { + val metadataRow = new GenericInternalRow(metadataColumns.map { + case "index" => index + case "_partition" => UTF8String.fromString(partition.key) + }.toArray) + new JoinedRow(row, metadataRow) + } + + private var index: Int = -1 + + override def next(): Boolean = { + index += 1 + index < partition.rows.length + } + + override def get(): InternalRow = addMetadata(partition.rows(index)) + + override def close(): Unit = {} +} + +private object BufferedRowsWriterFactory extends DataWriterFactory with StreamingDataWriterFactory { + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + new BufferWriter + } + + override def createWriter( + partitionId: Int, + taskId: Long, + epochId: Long): DataWriter[InternalRow] = { + new BufferWriter + } +} + +private class BufferWriter extends DataWriter[InternalRow] { + private val buffer = new BufferedRows + + override def write(row: InternalRow): Unit = buffer.rows.append(row.copy()) + + override def commit(): WriterCommitMessage = buffer + + override def abort(): Unit = {} + + override def close(): Unit = {} +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala new file mode 100644 index 0000000000..38113f9ea1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class BasicInMemoryTableCatalog extends TableCatalog { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + protected val namespaces: util.Map[List[String], Map[String, String]] = + new ConcurrentHashMap[List[String], Map[String, String]]() + + protected val tables: util.Map[Identifier, Table] = + new ConcurrentHashMap[Identifier, Table]() + + private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet() + + private var _name: Option[String] = None + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { + _name = Some(name) + } + + override def name: String = _name.get + + override def listTables(namespace: Array[String]): Array[Identifier] = { + tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray + } + + override def loadTable(ident: Identifier): Table = { + Option(tables.get(ident)) match { + case Some(table) => + table + case _ => + throw new NoSuchTableException(ident) + } + } + + override def invalidateTable(ident: Identifier): Unit = { + invalidatedTables.add(ident) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + createTable(ident, schema, partitions, properties, Distributions.unspecified(), + Array.empty, None) + } + + def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + distribution: Distribution, + ordering: Array[SortOrder], + requiredNumPartitions: Option[Int]): Table = { + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident) + } + + InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) + + val tableName = s"$name.${ident.quoted}" + val table = new InMemoryTable(tableName, schema, partitions, properties, distribution, + ordering, requiredNumPartitions) + tables.put(ident, table) + namespaces.putIfAbsent(ident.namespace.toList, Map()) + table + } + + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + val table = loadTable(ident).asInstanceOf[InMemoryTable] + val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes) + val schema = CatalogV2Util.applySchemaChanges(table.schema, changes) + + // fail if the last column in the schema was dropped + if (schema.fields.isEmpty) { + throw new IllegalArgumentException(s"Cannot drop all fields") + } + + val newTable = new InMemoryTable(table.name, schema, table.partitioning, 
properties) + .withData(table.data) + + tables.put(ident, newTable) + + newTable + } + + override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined + + override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { + if (tables.containsKey(newIdent)) { + throw new TableAlreadyExistsException(newIdent) + } + + Option(tables.remove(oldIdent)) match { + case Some(table) => + tables.put(newIdent, table) + case _ => + throw new NoSuchTableException(oldIdent) + } + } + + def isTableInvalidated(ident: Identifier): Boolean = { + invalidatedTables.contains(ident) + } + + def clearTables(): Unit = { + tables.clear() + } +} + +class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamespaces { + private def allNamespaces: Seq[Seq[String]] = { + (tables.keySet.asScala.map(_.namespace.toSeq) ++ namespaces.keySet.asScala).toSeq.distinct + } + + override def namespaceExists(namespace: Array[String]): Boolean = { + allNamespaces.exists(_.startsWith(namespace)) + } + + override def listNamespaces: Array[Array[String]] = { + allNamespaces.map(_.head).distinct.map(Array(_)).toArray + } + + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + allNamespaces + .filter(_.size > namespace.length) + .filter(_.startsWith(namespace)) + .map(_.take(namespace.length + 1)) + .distinct + .map(_.toArray) + .toArray + } + + override def loadNamespaceMetadata(namespace: Array[String]): util.Map[String, String] = { + Option(namespaces.get(namespace.toSeq)) match { + case Some(metadata) => + metadata.asJava + case _ if namespaceExists(namespace) => + util.Collections.emptyMap[String, String] + case _ => + throw new NoSuchNamespaceException(namespace) + } + } + + override def createNamespace( + namespace: Array[String], + metadata: util.Map[String, String]): Unit = { + if (namespaceExists(namespace)) { + throw new NamespaceAlreadyExistsException(namespace) + } + + Option(namespaces.putIfAbsent(namespace.toList, metadata.asScala.toMap)) match { + case Some(_) => + throw new NamespaceAlreadyExistsException(namespace) + case _ => + // created successfully + } + } + + override def alterNamespace( + namespace: Array[String], + changes: NamespaceChange*): Unit = { + val metadata = loadNamespaceMetadata(namespace).asScala.toMap + namespaces.put(namespace.toList, CatalogV2Util.applyNamespaceChanges(metadata, changes)) + } + + override def dropNamespace(namespace: Array[String]): Boolean = { + listNamespaces(namespace).foreach(dropNamespace) + try { + listTables(namespace).foreach(dropTable) + } catch { + case _: NoSuchNamespaceException => + } + Option(namespaces.remove(namespace.toList)).isDefined + } + + override def listTables(namespace: Array[String]): Array[Identifier] = { + if (namespace.isEmpty || namespaceExists(namespace)) { + super.listTables(namespace) + } else { + throw new NoSuchNamespaceException(namespace) + } + } +} + +object InMemoryTableCatalog { + val SIMULATE_FAILED_CREATE_PROPERTY = "spark.sql.test.simulateFailedCreate" + val SIMULATE_DROP_BEFORE_REPLACE_PROPERTY = "spark.sql.test.simulateDropBeforeReplace" + + def maybeSimulateFailedTableCreation(tableProperties: util.Map[String, String]): Unit = { + if ("true".equalsIgnoreCase(tableProperties.get(SIMULATE_FAILED_CREATE_PROPERTY))) { + throw new IllegalStateException("Manual create table failure.") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala new file mode 100644 index 0000000000..954650ae0e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.util + +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class StagingInMemoryTableCatalog extends InMemoryTableCatalog with StagingTableCatalog { + import InMemoryTableCatalog._ + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + override def stageCreate( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = { + validateStagedTable(partitions, properties) + new TestStagedCreateTable( + ident, + new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)) + } + + override def stageReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = { + validateStagedTable(partitions, properties) + new TestStagedReplaceTable( + ident, + new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)) + } + + override def stageCreateOrReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = { + validateStagedTable(partitions, properties) + new TestStagedCreateOrReplaceTable( + ident, + new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)) + } + + private def validateStagedTable( + partitions: Array[Transform], + properties: util.Map[String, String]): Unit = { + if (partitions.nonEmpty) { + throw new UnsupportedOperationException( + s"Catalog $name: Partitioned tables are not supported") + } + + maybeSimulateFailedTableCreation(properties) + } + + private abstract class TestStagedTable( + ident: Identifier, + delegateTable: InMemoryTable) + extends StagedTable with SupportsWrite with SupportsRead { + + override def abortStagedChanges(): Unit = {} + + override def name(): String = delegateTable.name + + override def schema(): StructType = delegateTable.schema + + override def capabilities(): util.Set[TableCapability] = delegateTable.capabilities + + override def 
newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + delegateTable.newWriteBuilder(info) + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + delegateTable.newScanBuilder(options) + } + } + + private class TestStagedCreateTable( + ident: Identifier, + delegateTable: InMemoryTable) extends TestStagedTable(ident, delegateTable) { + + override def commitStagedChanges(): Unit = { + val maybePreCommittedTable = tables.putIfAbsent(ident, delegateTable) + if (maybePreCommittedTable != null) { + throw new TableAlreadyExistsException( + s"Table with identifier $ident and name $name was already created.") + } + } + } + + private class TestStagedReplaceTable( + ident: Identifier, + delegateTable: InMemoryTable) extends TestStagedTable(ident, delegateTable) { + + override def commitStagedChanges(): Unit = { + maybeSimulateDropBeforeCommit() + val maybePreCommittedTable = tables.replace(ident, delegateTable) + if (maybePreCommittedTable == null) { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + } + + private def maybeSimulateDropBeforeCommit(): Unit = { + if ("true".equalsIgnoreCase( + delegateTable.properties.get(SIMULATE_DROP_BEFORE_REPLACE_PROPERTY))) { + tables.remove(ident) + } + } + } + + private class TestStagedCreateOrReplaceTable( + ident: Identifier, + delegateTable: InMemoryTable) extends TestStagedTable(ident, delegateTable) { + + override def commitStagedChanges(): Unit = { + tables.put(ident, delegateTable) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala index ecfc6adff7..df2fbd6d17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagementSuite.scala @@ -22,7 +22,6 @@ import java.util import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionsAlreadyExistException} -import org.apache.spark.sql.connector.{BufferedRows, InMemoryAtomicPartitionTable, InMemoryTableCatalog} import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference} import org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala index c95c459721..e5aeb90b84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException} -import org.apache.spark.sql.connector.{BufferedRows, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog} import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference} import 
org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala index 485e41f9eb..5560bda928 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TableCatalogSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.connector.{BufferedRows, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTable, InMemoryTableCatalog} import org.apache.spark.sql.connector.expressions.LogicalExpressions import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala index e6565feebf..5ae74c5eaf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala @@ -48,6 +48,8 @@ class ArrowUtilsSuite extends SparkFunSuite { roundtrip(BinaryType) roundtrip(DecimalType.SYSTEM_DEFAULT) roundtrip(DateType) + roundtrip(YearMonthIntervalType) + roundtrip(DayTimeIntervalType) val tsExMsg = intercept[UnsupportedOperationException] { roundtrip(TimestampType) } diff --git a/sql/core/benchmarks/AggregateBenchmark-jdk11-results.txt b/sql/core/benchmarks/AggregateBenchmark-jdk11-results.txt index 546face681..7b1e82d64a 100644 --- a/sql/core/benchmarks/AggregateBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/AggregateBenchmark-jdk11-results.txt @@ -2,142 +2,147 @@ aggregate without grouping ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz agg w/o group: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -agg w/o group wholestage off 63666 64021 502 32.9 30.4 1.0X -agg w/o group wholestage on 882 912 37 2376.9 0.4 72.2X +agg w/o group wholestage off 82274 82877 853 25.5 39.2 1.0X +agg w/o group wholestage on 1322 1358 37 1586.7 0.6 62.2X ================================================================================================ stat functions ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz stddev: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -stddev wholestage off 7370 7688 450 14.2 70.3 1.0X -stddev wholestage on 931 997 50 112.6 8.9 7.9X +stddev wholestage off 8975 9129 219 11.7 85.6 1.0X +stddev wholestage on 1424 1444 34 73.6 13.6 6.3X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz kurtosis: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -kurtosis wholestage off 30901 31209 436 3.4 294.7 1.0X -kurtosis wholestage on 950 996 33 110.4 9.1 32.5X +kurtosis wholestage off 42273 42424 213 2.5 403.1 1.0X +kurtosis wholestage on 1492 1528 27 70.3 14.2 28.3X ================================================================================================ aggregate with linear keys ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Aggregate w keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 8845 8874 41 9.5 105.4 1.0X -codegen = T hashmap = F 5804 5854 47 14.5 69.2 1.5X -codegen = T hashmap = T 954 1001 35 87.9 11.4 9.3X +codegen = F 10873 10998 176 7.7 129.6 1.0X +codegen = T, hashmap = F 5906 6005 95 14.2 70.4 1.8X +codegen = T, row-based hashmap = T 2325 2410 94 36.1 27.7 4.7X +codegen = T, vectorized hashmap = T 1185 1259 78 70.8 14.1 9.2X ================================================================================================ aggregate with randomized keys ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Aggregate w keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 10398 10788 552 8.1 124.0 1.0X -codegen = T hashmap = F 7426 7520 84 11.3 88.5 1.4X -codegen = T hashmap = T 1883 1917 31 44.5 22.4 5.5X +codegen = F 12385 12470 120 6.8 147.6 1.0X +codegen = T, hashmap = F 7734 8110 378 10.8 92.2 1.6X +codegen = T, row-based hashmap = T 3663 3702 37 22.9 43.7 3.4X +codegen = T, vectorized hashmap = T 2532 2621 54 33.1 30.2 4.9X ================================================================================================ aggregate with string key ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Aggregate w string key: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -codegen = F 3615 3888 386 5.8 172.4 1.0X -codegen = T hashmap = F 2253 2381 168 9.3 107.4 1.6X -codegen = T hashmap = T 1242 1316 59 16.9 59.2 2.9X +codegen = F 4465 4517 73 4.7 212.9 1.0X +codegen = T, hashmap = F 2667 2825 208 7.9 127.2 1.7X +codegen = T, row-based hashmap = T 1436 1466 21 14.6 68.5 3.1X +codegen = T, vectorized hashmap = T 1297 1301 5 16.2 61.8 3.4X ================================================================================================ aggregate with decimal key ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Aggregate w decimal key: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 3437 3534 137 6.1 163.9 1.0X -codegen = T hashmap = F 2122 2226 147 9.9 101.2 1.6X -codegen = T hashmap = T 638 678 36 32.9 30.4 5.4X +codegen = F 3722 3746 34 5.6 177.5 1.0X +codegen = T, hashmap = F 2229 2297 96 9.4 106.3 1.7X +codegen = T, row-based hashmap = T 927 957 28 22.6 44.2 4.0X +codegen = T, vectorized hashmap = T 772 796 22 27.2 36.8 4.8X ================================================================================================ aggregate with multiple key types ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz Aggregate w multiple keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 6549 6648 140 3.2 312.3 1.0X -codegen = T hashmap = F 3591 3693 144 5.8 171.2 1.8X -codegen = T hashmap = T 2822 2922 141 7.4 134.6 2.3X +codegen = F 7013 7060 67 3.0 334.4 1.0X +codegen = T, hashmap = F 3750 3894 205 5.6 178.8 1.9X +codegen = T, row-based hashmap = T 2948 2952 5 7.1 140.6 2.4X +codegen = T, vectorized hashmap = T 2986 3145 226 7.0 142.4 2.3X ================================================================================================ max function bytecode size of wholestagecodegen ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz max function bytecode size: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 531 571 36 1.2 810.7 1.0X -codegen = T hugeMethodLimit = 10000 223 282 36 2.9 340.1 2.4X -codegen = T hugeMethodLimit = 1500 264 308 27 2.5 402.2 2.0X +codegen = F 567 620 37 1.2 864.6 1.0X +codegen = T, hugeMethodLimit = 10000 283 316 26 2.3 431.9 2.0X +codegen = T, hugeMethodLimit = 1500 275 324 40 2.4 420.2 
2.1X ================================================================================================ cube ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz cube: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -cube wholestage off 2963 3099 193 1.8 565.1 1.0X -cube wholestage on 1624 1767 98 3.2 309.8 1.8X +cube wholestage off 3389 3476 123 1.5 646.4 1.0X +cube wholestage on 1692 1726 34 3.1 322.7 2.0X ================================================================================================ hash and BytesToBytesMap ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.11+9-LTS on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) CPU E5-2673 v3 @ 2.40GHz BytesToBytesMap: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -UnsafeRowhash 247 268 19 84.8 11.8 1.0X -murmur3 hash 99 123 40 211.3 4.7 2.5X -fast hash 56 66 5 374.0 2.7 4.4X -arrayEqual 186 200 8 113.0 8.8 1.3X -Java HashMap (Long) 121 207 65 173.5 5.8 2.0X -Java HashMap (two ints) 147 233 61 142.8 7.0 1.7X -Java HashMap (UnsafeRow) 733 778 45 28.6 34.9 0.3X -LongToUnsafeRowMap (opt=false) 489 504 15 42.8 23.3 0.5X -LongToUnsafeRowMap (opt=true) 125 154 29 168.2 5.9 2.0X -BytesToBytesMap (off Heap) 840 895 48 25.0 40.1 0.3X -BytesToBytesMap (on Heap) 853 904 60 24.6 40.7 0.3X -Aggregate HashMap 38 46 8 546.3 1.8 6.4X +UnsafeRowhash 302 306 4 69.5 14.4 1.0X +murmur3 hash 125 129 3 167.8 6.0 2.4X +fast hash 69 73 3 304.1 3.3 4.4X +arrayEqual 192 195 3 109.0 9.2 1.6X +Java HashMap (Long) 133 187 53 157.2 6.4 2.3X +Java HashMap (two ints) 156 230 62 134.3 7.4 1.9X +Java HashMap (UnsafeRow) 807 812 6 26.0 38.5 0.4X +LongToUnsafeRowMap (opt=false) 502 529 24 41.8 23.9 0.6X +LongToUnsafeRowMap (opt=true) 148 164 20 141.7 7.1 2.0X +BytesToBytesMap (off Heap) 936 950 23 22.4 44.6 0.3X +BytesToBytesMap (on Heap) 954 956 2 22.0 45.5 0.3X +Aggregate HashMap 46 54 11 455.4 2.2 6.6X diff --git a/sql/core/benchmarks/AggregateBenchmark-results.txt b/sql/core/benchmarks/AggregateBenchmark-results.txt index f18c470831..d4de806d03 100644 --- a/sql/core/benchmarks/AggregateBenchmark-results.txt +++ b/sql/core/benchmarks/AggregateBenchmark-results.txt @@ -2,142 +2,147 @@ aggregate without grouping ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz agg w/o group: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -agg w/o group wholestage off 47798 50190 NaN 43.9 22.8 1.0X -agg w/o group wholestage on 1091 1128 28 1922.6 0.5 43.8X +agg w/o 
group wholestage off 53440 63455 NaN 39.2 25.5 1.0X +agg w/o group wholestage on 1157 1216 39 1812.5 0.6 46.2X ================================================================================================ stat functions ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz stddev: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -stddev wholestage off 7884 7959 106 13.3 75.2 1.0X -stddev wholestage on 1012 1072 34 103.6 9.6 7.8X +stddev wholestage off 7920 7947 39 13.2 75.5 1.0X +stddev wholestage on 1147 1160 11 91.4 10.9 6.9X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz kurtosis: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -kurtosis wholestage off 34023 34576 783 3.1 324.5 1.0X -kurtosis wholestage on 1092 1121 30 96.1 10.4 31.2X +kurtosis wholestage off 35143 35319 250 3.0 335.1 1.0X +kurtosis wholestage on 1239 1258 20 84.6 11.8 28.4X ================================================================================================ aggregate with linear keys ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Aggregate w keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 9309 9379 99 9.0 111.0 1.0X -codegen = T hashmap = F 5453 5643 223 15.4 65.0 1.7X -codegen = T hashmap = T 1084 1110 16 77.4 12.9 8.6X +codegen = F 9147 9183 50 9.2 109.0 1.0X +codegen = T, hashmap = F 5794 5949 226 14.5 69.1 1.6X +codegen = T, row-based hashmap = T 1378 1397 14 60.9 16.4 6.6X +codegen = T, vectorized hashmap = T 996 1034 25 84.3 11.9 9.2X ================================================================================================ aggregate with randomized keys ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Aggregate w keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 10707 10950 344 7.8 127.6 1.0X -codegen = T hashmap = F 7295 7423 145 11.5 87.0 1.5X -codegen = T hashmap = T 2057 2199 199 40.8 24.5 5.2X +codegen = F 9356 9425 98 9.0 111.5 1.0X +codegen = T, hashmap = F 5787 5912 176 14.5 69.0 1.6X +codegen = T, row-based hashmap = T 2569 2602 49 
32.7 30.6 3.6X +codegen = T, vectorized hashmap = T 2094 2128 27 40.1 25.0 4.5X ================================================================================================ aggregate with string key ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Aggregate w string key: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 4570 4573 4 4.6 217.9 1.0X -codegen = T hashmap = F 3600 3686 74 5.8 171.7 1.3X -codegen = T hashmap = T 2384 2432 45 8.8 113.7 1.9X +codegen = F 4270 4322 75 4.9 203.6 1.0X +codegen = T, hashmap = F 3241 3264 30 6.5 154.6 1.3X +codegen = T, row-based hashmap = T 2196 2247 32 9.6 104.7 1.9X +codegen = T, vectorized hashmap = T 2291 2306 14 9.2 109.3 1.9X ================================================================================================ aggregate with decimal key ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Aggregate w decimal key: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 2966 3011 64 7.1 141.4 1.0X -codegen = T hashmap = F 1857 1908 73 11.3 88.5 1.6X -codegen = T hashmap = T 695 702 8 30.2 33.2 4.3X +codegen = F 2993 3010 23 7.0 142.7 1.0X +codegen = T, hashmap = F 1940 1945 7 10.8 92.5 1.5X +codegen = T, row-based hashmap = T 738 752 20 28.4 35.2 4.1X +codegen = T, vectorized hashmap = T 620 650 21 33.8 29.6 4.8X ================================================================================================ aggregate with multiple key types ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Aggregate w multiple keys: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 7361 7385 35 2.8 351.0 1.0X -codegen = T hashmap = F 4525 4688 231 4.6 215.8 1.6X -codegen = T hashmap = T 3865 3977 159 5.4 184.3 1.9X +codegen = F 6635 6636 2 3.2 316.4 1.0X +codegen = T, hashmap = F 4236 4269 47 5.0 202.0 1.6X +codegen = T, row-based hashmap = T 3118 3158 57 6.7 148.7 2.1X +codegen = T, vectorized hashmap = T 3259 3278 27 6.4 155.4 2.0X ================================================================================================ max function bytecode size of wholestagecodegen ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 
on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz max function bytecode size: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -codegen = F 451 489 23 1.5 688.5 1.0X -codegen = T hugeMethodLimit = 10000 211 229 19 3.1 322.4 2.1X -codegen = T hugeMethodLimit = 1500 203 226 20 3.2 309.5 2.2X +codegen = F 467 492 33 1.4 712.4 1.0X +codegen = T, hugeMethodLimit = 10000 216 231 19 3.0 329.7 2.2X +codegen = T, hugeMethodLimit = 1500 209 221 9 3.1 319.0 2.2X ================================================================================================ cube ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz cube: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -cube wholestage off 2479 2548 97 2.1 472.9 1.0X -cube wholestage on 1487 1567 62 3.5 283.7 1.7X +cube wholestage off 2490 2529 56 2.1 474.8 1.0X +cube wholestage on 1401 1416 22 3.7 267.3 1.8X ================================================================================================ hash and BytesToBytesMap ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 1.8.0_292-b10 on Linux 5.4.0-1046-azure +Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz BytesToBytesMap: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -UnsafeRowhash 826 837 16 25.4 39.4 1.0X -murmur3 hash 537 553 11 39.1 25.6 1.5X -fast hash 559 572 14 37.5 26.6 1.5X -arrayEqual 1665 1728 90 12.6 79.4 0.5X -Java HashMap (Long) 732 739 7 28.7 34.9 1.1X -Java HashMap (two ints) 682 694 15 30.7 32.5 1.2X -Java HashMap (UnsafeRow) 1486 1499 19 14.1 70.9 0.6X -LongToUnsafeRowMap (opt=false) 1235 1240 8 17.0 58.9 0.7X -LongToUnsafeRowMap (opt=true) 718 736 17 29.2 34.2 1.2X -BytesToBytesMap (off Heap) 945 965 20 22.2 45.1 0.9X -BytesToBytesMap (on Heap) 870 895 28 24.1 41.5 0.9X -Aggregate HashMap 64 71 5 325.6 3.1 12.8X +UnsafeRowhash 259 264 5 81.0 12.3 1.0X +murmur3 hash 113 121 3 185.7 5.4 2.3X +fast hash 84 87 2 249.8 4.0 3.1X +arrayEqual 172 180 4 121.9 8.2 1.5X +Java HashMap (Long) 155 161 5 135.2 7.4 1.7X +Java HashMap (two ints) 147 157 8 142.6 7.0 1.8X +Java HashMap (UnsafeRow) 739 742 4 28.4 35.2 0.4X +LongToUnsafeRowMap (opt=false) 489 491 3 42.9 23.3 0.5X +LongToUnsafeRowMap (opt=true) 93 100 6 224.8 4.4 2.8X +BytesToBytesMap (off Heap) 882 896 16 23.8 42.1 0.3X +BytesToBytesMap (on Heap) 833 863 36 25.2 39.7 0.3X +Aggregate HashMap 66 69 1 317.0 3.2 3.9X diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 829025f3dc..0795776eb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -214,11 +214,13 @@ class QueryExecution( QueryPlan.append(logical, append, verbose, addSuffix, maxFields) append("\n== Analyzed Logical Plan ==\n") try { - append( - truncatedString( - analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ", maxFields) - ) - append("\n") + if (analyzed.output.nonEmpty) { + append( + truncatedString( + analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ", maxFields) + ) + append("\n") + } QueryPlan.append(analyzed, append, verbose, addSuffix, maxFields) append("\n== Optimized Logical Plan ==\n") QueryPlan.append(optimizedPlan, append, verbose, addSuffix, maxFields) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala index bd45863652..d50e32c8b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.physical.SinglePartition -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan} import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.internal.SQLConf @@ -54,8 +54,21 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) { plan } else { + def insertCustomShuffleReader(partitionSpecs: Seq[ShufflePartitionSpec]): SparkPlan = { + // This transformation adds new nodes, so we must use `transformUp` here. + val stageIds = shuffleStages.map(_.id).toSet + plan.transformUp { + // even for shuffle exchange whose input RDD has 0 partition, we should still update its + // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same + // number of output partitions. + case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) => + CustomShuffleReaderExec(stage, partitionSpecs) + } + } + // `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions, // we should skip it when calculating the `partitionStartIndices`. + // If all input RDDs have 0 partition, we create empty partition for every shuffle reader. val validMetrics = shuffleStages.flatMap(_.mapStats) // We may have different pre-shuffle partition numbers, don't reduce shuffle partition number @@ -63,7 +76,9 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl // partition) and a result of a SortMergeJoin (multiple partitions). val distinctNumPreShufflePartitions = validMetrics.map(stats => stats.bytesByPartitionId.length).distinct - if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) { + if (validMetrics.isEmpty) { + insertCustomShuffleReader(ShufflePartitionsUtil.createEmptyPartition() :: Nil) + } else if (distinctNumPreShufflePartitions.length == 1) { // We fall back to Spark default parallelism if the minimum number of coalesced partitions // is not set, so to avoid perf regressions compared to no coalescing. 
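The coalescing rule above is easier to follow outside the diff. Below is a simplified, self-contained model (the `Spec` alias, the `coalesce` helper, and the target-size heuristic are assumptions made for this sketch, not Spark's `CoalesceShufflePartitions` internals): when no shuffle reports map statistics, a single empty partition spec is produced so every reader in the stage agrees on the partitioning; otherwise adjacent partitions are grouped up to a target size.

```scala
import scala.collection.mutable.ArrayBuffer

// Illustrative sketch only; not Spark's coalescing code.
object CoalesceSketch {
  // (startIndex, endIndex) with end exclusive, mirroring a CoalescedPartitionSpec.
  type Spec = (Int, Int)

  def coalesce(bytesByPartition: Option[Array[Long]], targetSize: Long): Seq[Spec] =
    bytesByPartition match {
      case None =>
        // No map statistics at all (e.g. every input RDD had 0 partitions):
        // emit one empty partition so all shuffle readers stay consistent.
        Seq((0, 0))
      case Some(sizes) =>
        val specs = ArrayBuffer.empty[Spec]
        var start = 0
        var acc = 0L
        sizes.indices.foreach { i =>
          if (i > start && acc + sizes(i) > targetSize) {
            specs += ((start, i))
            start = i
            acc = 0L
          }
          acc += sizes(i)
        }
        specs += ((start, sizes.length))
        specs.toSeq
    }

  def main(args: Array[String]): Unit = {
    println(coalesce(None, targetSize = 64))                                // List((0,0))
    println(coalesce(Some(Array(10L, 20L, 40L, 5L, 70L)), targetSize = 64)) // List((0,2), (2,4), (4,5))
  }
}
```

The real rule additionally honours a minimum partition count and only rewrites stages whose shuffle origins support coalescing; the empty-partition fallback in the sketch corresponds to the new `ShufflePartitionsUtil.createEmptyPartition()` helper.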
val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM) @@ -77,15 +92,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl if (partitionSpecs.length == distinctNumPreShufflePartitions.head) { plan } else { - // This transformation adds new nodes, so we must use `transformUp` here. - val stageIds = shuffleStages.map(_.id).toSet - plan.transformUp { - // even for shuffle exchange whose input RDD has 0 partition, we should still update its - // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same - // number of output partitions. - case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) => - CustomShuffleReaderExec(stage, partitionSpecs) - } + insertCustomShuffleReader(partitionSpecs) } } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index d98b7c29a3..1065519256 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ListQuery, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, SCALAR_SUBQUERY} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} import org.apache.spark.sql.execution.datasources.v2.V2CommandExec @@ -113,6 +114,9 @@ case class InsertAdaptiveSparkPlan( */ private def buildSubqueryMap(plan: SparkPlan): Map[Long, BaseSubqueryExec] = { val subqueryMap = mutable.HashMap.empty[Long, BaseSubqueryExec] + if (!plan.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) { + return subqueryMap.toMap + } plan.foreach(_.expressions.foreach(_.foreach { case expressions.ScalarSubquery(p, _, exprId) if !subqueryMap.contains(exprId.id) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index 13ff236d20..a2e4397a36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, DynamicPruningExpression, ListQuery, Literal} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, + SCALAR_SUBQUERY} import org.apache.spark.sql.execution import org.apache.spark.sql.execution.{BaseSubqueryExec, InSubqueryExec, SparkPlan} @@ -27,7 +29,8 @@ case class PlanAdaptiveSubqueries( subqueryMap: Map[Long, BaseSubqueryExec]) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { - plan.transformAllExpressions { + plan.transformAllExpressionsWithPruning( + _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, 
DYNAMIC_PRUNING_SUBQUERY)) { case expressions.ScalarSubquery(_, _, exprId) => execution.ScalarSubquery(subqueryMap(exprId.id), exprId) case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala index ed92af6adc..a70a5322a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -125,6 +125,10 @@ object ShufflePartitionsUtil extends Logging { partitionSpecs.toSeq } + def createEmptyPartition(): ShufflePartitionSpec = { + CoalescedPartitionSpec(0, 0) + } + /** * Given a list of size, return an array of indices to split the list into multiple partitions, * so that the size sum of each partition is close to the target size. Each index indicates the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 6e23a2844d..3c1304e9cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit._ import scala.collection.mutable import org.apache.spark.TaskContext -import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} +import org.apache.spark.memory.SparkOutOfMemoryError import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -435,8 +435,8 @@ case class HashAggregateExec( ) } - def getTaskMemoryManager(): TaskMemoryManager = { - TaskContext.get().taskMemoryManager() + def getTaskContext(): TaskContext = { + TaskContext.get() } def getEmptyAggregationBuffer(): InternalRow = { @@ -647,7 +647,7 @@ case class HashAggregateExec( (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) || f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] || f.dataType.isInstanceOf[CalendarIntervalType]) && - bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge) + bufferSchema.nonEmpty // For vectorized hash map, We do not support byte array based decimal type for aggregate values // as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place @@ -663,7 +663,7 @@ case class HashAggregateExec( private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = { if (!checkIfFastHashMapSupported(ctx)) { - if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) { + if (!Utils.isTesting) { logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but" + " current version of codegened fast hashmap does not support this aggregate.") } @@ -683,7 +683,18 @@ case class HashAggregateExec( } else if (sqlContext.conf.enableVectorizedHashMap) { logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.") } - val bitMaxCapacity = sqlContext.conf.fastHashAggregateRowMaxCapacityBit + val bitMaxCapacity = testFallbackStartsAt match { + case Some((fastMapCounter, _)) => + // In testing, with fall back counter of fast hash map (`fastMapCounter`), set the max bit + // of map to be no more than 
log2(`fastMapCounter`). This helps control the number of keys + // in map to mimic fall back. + if (fastMapCounter <= 1) { + 0 + } else { + (math.log10(fastMapCounter) / math.log10(2)).floor.toInt + } + case _ => sqlContext.conf.fastHashAggregateRowMaxCapacityBit + } val thisPlan = ctx.addReferenceObj("plan", this) @@ -717,11 +728,28 @@ case class HashAggregateExec( "org.apache.spark.unsafe.KVIterator", "fastHashMapIter", forceInline = true) val create = s"$fastHashMapTerm = new $fastHashMapClassName(" + - s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());" + s"$thisPlan.getTaskContext().taskMemoryManager(), " + + s"$thisPlan.getEmptyAggregationBuffer());" (iter, create) } } else ("", "") + // Generates the code to register a cleanup task with TaskContext to ensure that memory + // is guaranteed to be freed at the end of the task. This is necessary to avoid memory + // leaks in when the downstream operator does not fully consume the aggregation map's + // output (e.g. aggregate followed by limit). + val addHookToCloseFastHashMap = if (isFastHashMapEnabled) { + s""" + |$thisPlan.getTaskContext().addTaskCompletionListener( + | new org.apache.spark.util.TaskCompletionListener() { + | @Override + | public void onTaskCompletion(org.apache.spark.TaskContext context) { + | $fastHashMapTerm.close(); + | } + |}); + """.stripMargin + } else "" + // Create a name for the iterator from the regular hash map. // Inline mutable state since not many aggregation operations in a task val iterTerm = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, @@ -761,6 +789,8 @@ case class HashAggregateExec( val bufferTerm = ctx.freshName("aggBuffer") val outputFunc = generateResultFunction(ctx) + val limitNotReachedCondition = limitNotReachedCond + def outputFromFastHashMap: String = { if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { @@ -773,7 +803,7 @@ case class HashAggregateExec( def outputFromRowBasedMap: String = { s""" - |while ($iterTermForFastHashMap.next()) { + |while ($limitNotReachedCondition $iterTermForFastHashMap.next()) { | UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey(); | UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue(); | $outputFunc($keyTerm, $bufferTerm); @@ -798,7 +828,7 @@ case class HashAggregateExec( BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) }) s""" - |while ($iterTermForFastHashMap.hasNext()) { + |while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) { | InternalRow $row = (InternalRow) $iterTermForFastHashMap.next(); | ${generateKeyRow.code} | ${generateBufferRow.code} @@ -813,7 +843,7 @@ case class HashAggregateExec( def outputFromRegularHashMap: String = { s""" - |while ($limitNotReachedCond $iterTerm.next()) { + |while ($limitNotReachedCondition $iterTerm.next()) { | UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); | UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); | $outputFunc($keyTerm, $bufferTerm); @@ -832,6 +862,7 @@ case class HashAggregateExec( |if (!$initAgg) { | $initAgg = true; | $createFastHashMap + | $addHookToCloseFastHashMap | $hashMapTerm = $thisPlan.createHashMap(); | long $beforeAgg = System.nanoTime(); | $doAggFuncName(); @@ -866,13 +897,11 @@ case class HashAggregateExec( } } - val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, - incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") - (s"$countTerm < 
${testFallbackStartsAt.get._1}", - s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") - } else { - ("true", "true", "", "") + val (checkFallbackForBytesToBytesMap, resetCounter, incCounter) = testFallbackStartsAt match { + case Some((_, regularMapCounter)) => + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") + (s"$countTerm < $regularMapCounter", s"$countTerm = 0;", s"$countTerm += 1;") + case _ => ("true", "", "") } val oomeClassName = classOf[SparkOutOfMemoryError].getName @@ -912,12 +941,10 @@ case class HashAggregateExec( // If fast hash map is on, we first generate code to probe and update the fast hash map. // If the probe is successful the corresponding fast row buffer will hold the mutable row. s""" - |if ($checkFallbackForGeneratedHashMap) { - | ${fastRowKeys.map(_.code).mkString("\n")} - | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { - | $fastRowBuffer = $fastHashMapTerm.findOrInsert( - | ${fastRowKeys.map(_.value).mkString(", ")}); - | } + |${fastRowKeys.map(_.code).mkString("\n")} + |if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { + | $fastRowBuffer = $fastHashMapTerm.findOrInsert( + | ${fastRowKeys.map(_.value).mkString(", ")}); |} |// Cannot find the key in fast hash map, try regular hash map. |if ($fastRowBuffer == null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index d3a02f2451..fcae7ac32b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_MILLIS} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils @@ -74,6 +75,8 @@ object ArrowWriter { } new StructWriter(vector, children.toArray) case (NullType, vector: NullVector) => new NullWriter(vector) + case (YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector) + case (DayTimeIntervalType, vector: IntervalDayVector) => new IntervalDayWriter(vector) case (dt, _) => throw QueryExecutionErrors.unsupportedDataTypeError(dt) } @@ -394,3 +397,28 @@ private[arrow] class NullWriter(val valueVector: NullVector) extends ArrowFieldW override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { } } + +private[arrow] class IntervalYearWriter(val valueVector: IntervalYearVector) + extends ArrowFieldWriter { + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getInt(ordinal)); + } +} + +private[arrow] class IntervalDayWriter(val valueVector: IntervalDayVector) + extends ArrowFieldWriter { + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val totalMicroseconds = input.getLong(ordinal) + val days = totalMicroseconds / MICROS_PER_DAY + val millis = (totalMicroseconds % MICROS_PER_DAY) / MICROS_PER_MILLIS + valueVector.set(count, days.toInt, millis.toInt) + } +} diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index c6b6a21da5..c32d1d74c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -139,7 +139,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) s"showing ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") Seq(Row( SQLConf.SHUFFLE_PARTITIONS.key, - sparkSession.sessionState.conf.numShufflePartitions.toString)) + sparkSession.sessionState.conf.defaultNumShufflePartitions.toString)) } (keyValueOutput, runFunc) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index b6b07de8a5..4f60a9d4c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -53,11 +53,11 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) private[this] val partitions: mutable.ArrayBuffer[InternalRow] = mutable.ArrayBuffer.empty private[this] var numFiles: Int = 0 - private[this] var submittedFiles: Int = 0 + private[this] var numSubmittedFiles: Int = 0 private[this] var numBytes: Long = 0L private[this] var numRows: Long = 0L - private[this] var curFile: Option[String] = None + private[this] val submittedFiles = mutable.HashSet[String]() /** * Get the size of the file expected to have been written by a worker. @@ -134,23 +134,20 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) partitions.append(partitionValues) } - override def newBucket(bucketId: Int): Unit = { - // currently unhandled + override def newFile(filePath: String): Unit = { + submittedFiles += filePath + numSubmittedFiles += 1 } - override def newFile(filePath: String): Unit = { - statCurrentFile() - curFile = Some(filePath) - submittedFiles += 1 + override def closeFile(filePath: String): Unit = { + updateFileStats(filePath) + submittedFiles.remove(filePath) } - private def statCurrentFile(): Unit = { - curFile.foreach { path => - getFileSize(path).foreach { len => - numBytes += len - numFiles += 1 - } - curFile = None + private def updateFileStats(filePath: String): Unit = { + getFileSize(filePath).foreach { len => + numBytes += len + numFiles += 1 } } @@ -159,7 +156,8 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def getFinalStats(): WriteTaskStats = { - statCurrentFile() + submittedFiles.foreach(updateFileStats) + submittedFiles.clear() // Reports bytesWritten and recordsWritten to the Spark output metrics. Option(TaskContext.get()).map(_.taskMetrics().outputMetrics).foreach { outputMetrics => @@ -167,8 +165,8 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) outputMetrics.setRecordsWritten(numRows) } - if (submittedFiles != numFiles) { - logInfo(s"Expected $submittedFiles files, but only saw $numFiles. " + + if (numSubmittedFiles != numFiles) { + logInfo(s"Expected $numSubmittedFiles files, but only saw $numFiles. 
" + "This could be due to the output format not writing empty files, " + "or files being not immediately visible in the filesystem.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 6de9b1d7ce..8230737a61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.catalyst.InternalRow @@ -28,6 +29,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.util.SerializableConfiguration @@ -52,19 +55,35 @@ abstract class FileFormatDataWriter( protected val statsTrackers: Seq[WriteTaskStatsTracker] = description.statsTrackers.map(_.newTaskInstance()) - protected def releaseResources(): Unit = { + /** Release resources of `currentWriter`. */ + protected def releaseCurrentWriter(): Unit = { if (currentWriter != null) { try { currentWriter.close() + statsTrackers.foreach(_.closeFile(currentWriter.path())) } finally { currentWriter = null } } } - /** Writes a record */ + /** Release all resources. */ + protected def releaseResources(): Unit = { + // Call `releaseCurrentWriter()` by default, as this is the only resource to be released. + releaseCurrentWriter() + } + + /** Writes a record. */ def write(record: InternalRow): Unit + + /** Write an iterator of records. */ + def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { + while (iterator.hasNext) { + write(iterator.next()) + } + } + /** * Returns the summary of relative information which * includes the list of partition strings written out. The list of partitions is sent back @@ -144,34 +163,38 @@ class SingleDirectoryDataWriter( } /** - * Writes data to using dynamic partition writes, meaning this single function can write to + * Holds common logic for writing data with dynamic partition writes, meaning it can write to * multiple directories (partitions) or files (bucketing). */ -class DynamicPartitionDataWriter( +abstract class BaseDynamicPartitionDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol) extends FileFormatDataWriter(description, taskAttemptContext, committer) { /** Flag saying whether or not the data to be written out is partitioned. */ - private val isPartitioned = description.partitionColumns.nonEmpty + protected val isPartitioned = description.partitionColumns.nonEmpty /** Flag saying whether or not the data to be written out is bucketed. 
*/ - private val isBucketed = description.bucketIdExpression.isDefined + protected val isBucketed = description.bucketIdExpression.isDefined assert(isPartitioned || isBucketed, s"""DynamicPartitionWriteTask should be used for writing out data that's either - |partitioned or bucketed. In this case neither is true. - |WriteJobDescription: $description + |partitioned or bucketed. In this case neither is true. + |WriteJobDescription: $description """.stripMargin) - private var fileCounter: Int = _ - private var recordsInFile: Long = _ - private var currentPartitionValues: Option[UnsafeRow] = None - private var currentBucketId: Option[Int] = None + /** Number of records in current file. */ + protected var recordsInFile: Long = _ + + /** + * File counter for writing current partition or bucket. For same partition or bucket, + * we may have more than one file, due to number of records limit per file. + */ + protected var fileCounter: Int = _ /** Extracts the partition values out of an input row. */ - private lazy val getPartitionValues: InternalRow => UnsafeRow = { + protected lazy val getPartitionValues: InternalRow => UnsafeRow = { val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns) row => proj(row) } @@ -186,22 +209,24 @@ class DynamicPartitionDataWriter( if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) }) - /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns - * the partition string. */ + /** + * Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns + * the partition string. + */ private lazy val getPartitionPath: InternalRow => String = { val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns) row => proj(row).getString(0) } /** Given an input row, returns the corresponding `bucketId` */ - private lazy val getBucketId: InternalRow => Int = { + protected lazy val getBucketId: InternalRow => Int = { val proj = UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns) row => proj(row).getInt(0) } /** Returns the data columns to be written given an input row */ - private val getOutputRow = + protected val getOutputRow = UnsafeProjection.create(description.dataColumns, description.allColumns) /** @@ -209,13 +234,20 @@ class DynamicPartitionDataWriter( * If bucket id is specified, we will append it to the end of the file name, but before the * file extension, e.g. 
part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet * - * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * @param partitionValues the partition which all tuples being written by this OutputWriter * belong to - * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + * @param bucketId the bucket which all tuples being written by this OutputWriter belong to + * @param closeCurrentWriter close and release resource for current writer */ - private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = { + protected def renewCurrentWriter( + partitionValues: Option[InternalRow], + bucketId: Option[Int], + closeCurrentWriter: Boolean): Unit = { + recordsInFile = 0 - releaseResources() + if (closeCurrentWriter) { + releaseCurrentWriter() + } val partDir = partitionValues.map(getPartitionPath(_)) partDir.foreach(updatedPartitions.add) @@ -243,6 +275,51 @@ class DynamicPartitionDataWriter( statsTrackers.foreach(_.newFile(currentPath)) } + /** + * Open a new output writer when number of records exceeding limit. + * + * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * belong to + * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + */ + protected def renewCurrentWriterIfTooManyRecords( + partitionValues: Option[InternalRow], + bucketId: Option[Int]): Unit = { + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + renewCurrentWriter(partitionValues, bucketId, closeCurrentWriter = true) + } + + /** + * Writes the given record with current writer. + * + * @param record The record to write + */ + protected def writeRecord(record: InternalRow): Unit = { + val outputRow = getOutputRow(record) + currentWriter.write(outputRow) + statsTrackers.foreach(_.newRow(outputRow)) + recordsInFile += 1 + } +} + +/** + * Dynamic partition writer with single writer, meaning only one writer is opened at any time for + * writing. The records to be written are required to be sorted on partition and/or bucket + * column(s) before writing. + */ +class DynamicPartitionDataSingleWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) { + + private var currentPartitionValues: Option[UnsafeRow] = None + private var currentBucketId: Option[Int] = None + override def write(record: InternalRow): Unit = { val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None @@ -255,25 +332,199 @@ class DynamicPartitionDataWriter( } if (isBucketed) { currentBucketId = nextBucketId - statsTrackers.foreach(_.newBucket(currentBucketId.get)) } fileCounter = 0 - newOutputWriter(currentPartitionValues, currentBucketId) + renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true) } else if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { - // Exceeded the threshold in terms of the number of records per file. - // Create a new file by increasing the file counter. 
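The records-per-file rollover that `renewCurrentWriterIfTooManyRecords` factors out boils down to a small state machine. Here is a standalone sketch under stated assumptions: `RollingWriterSketch` is a hypothetical class, `println` stands in for real file I/O, and the part-file naming is only an approximation of Spark's.

```scala
// Hypothetical rollover sketch, not the Spark writer itself.
class RollingWriterSketch(maxRecordsPerFile: Long) {
  private var fileCounter = 0
  private var recordsInFile = 0L
  private var currentFile: Option[String] = None

  private def fileName(counter: Int): String = f"part-$counter%05d"

  def write(record: String): Unit = {
    if (currentFile.isEmpty) {
      openNewFile()
    } else if (maxRecordsPerFile > 0 && recordsInFile >= maxRecordsPerFile) {
      // Threshold reached: close the current file, bump the counter, open a new one.
      close()
      fileCounter += 1
      openNewFile()
    }
    recordsInFile += 1 // a real writer would serialize `record` here
  }

  def close(): Unit = currentFile.foreach { f =>
    println(s"closed $f with $recordsInFile records")
    currentFile = None
    recordsInFile = 0
  }

  private def openNewFile(): Unit = {
    currentFile = Some(fileName(fileCounter))
    recordsInFile = 0
    println(s"opened ${currentFile.get}")
  }
}

object RollingWriterDemo {
  def main(args: Array[String]): Unit = {
    val w = new RollingWriterSketch(maxRecordsPerFile = 2)
    (1 to 5).foreach(i => w.write(s"row-$i"))
    w.close() // part-00000, part-00001, part-00002 with 2, 2, 1 records
  }
}
```

Spark's writer additionally asserts that the counter never exceeds `MAX_FILE_COUNTER` and embeds the counter plus split and bucket identifiers in the final file name.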
- fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + renewCurrentWriterIfTooManyRecords(currentPartitionValues, currentBucketId) + } + writeRecord(record) + } +} + +/** + * Dynamic partition writer with concurrent writers, meaning multiple concurrent writers are opened + * for writing. + * + * The process has the following steps: + * - Step 1: Maintain a map of output writers per each partition and/or bucket columns. Keep all + * writers opened and write rows one by one. + * - Step 2: If number of concurrent writers exceeds limit, sort rest of rows on partition and/or + * bucket column(s). Write rows one by one, and eagerly close the writer when finishing + * each partition and/or bucket. + * + * Caller is expected to call `writeWithIterator()` instead of `write()` to write records. + */ +class DynamicPartitionDataConcurrentWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol, + concurrentOutputWriterSpec: ConcurrentOutputWriterSpec) + extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) + with Logging { + + /** Wrapper class to index a unique concurrent output writer. */ + private case class WriterIndex( + var partitionValues: Option[UnsafeRow], + var bucketId: Option[Int]) + + /** Wrapper class for status of a unique concurrent output writer. */ + private class WriterStatus( + var outputWriter: OutputWriter, + var recordsInFile: Long, + var fileCounter: Int) + + /** + * State to indicate if we are falling back to sort-based writer. + * Because we first try to use concurrent writers, its initial value is false. + */ + private var sorted: Boolean = false + private val concurrentWriters = mutable.HashMap[WriterIndex, WriterStatus]() + + /** + * The index for current writer. Intentionally make the index mutable and reusable. + * Avoid JVM GC issue when many short-living `WriterIndex` objects are created + * if switching between concurrent writers frequently. + */ + private val currentWriterId = WriterIndex(None, None) + + /** + * Release resources for all concurrent output writers. + */ + override protected def releaseResources(): Unit = { + currentWriter = null + concurrentWriters.values.foreach(status => { + if (status.outputWriter != null) { + try { + status.outputWriter.close() + } finally { + status.outputWriter = null + } + } + }) + concurrentWriters.clear() + } - newOutputWriter(currentPartitionValues, currentBucketId) + override def write(record: InternalRow): Unit = { + val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None + val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None + + if (currentWriterId.partitionValues != nextPartitionValues || + currentWriterId.bucketId != nextBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + if (currentWriter != null) { + if (!sorted) { + // Update writer status in concurrent writers map, because the writer is probably needed + // again later for writing other rows. + updateCurrentWriterStatusInMap() + } else { + // Remove writer status in concurrent writers map and release current writer resource, + // because the writer is not needed any more. 
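The concurrent-writer bookkeeping around this hunk is easier to see in isolation. In the sketch below (a simplified model, not `DynamicPartitionDataConcurrentWriter`: plain `String` keys stand in for the partition/bucket `WriterIndex`, and a `StringBuilder` stands in for an open `OutputWriter`), one buffer per key stays open until a cap is reached, after which the remaining rows are handled in a sort-based phase, one partition at a time.

```scala
import scala.collection.mutable

// Simplified model of concurrent writers with a sort-based fallback.
object ConcurrentWriterSketch {
  def writeAll(records: Iterator[(String, String)], maxWriters: Int): Unit = {
    val open = mutable.LinkedHashMap.empty[String, StringBuilder]
    val spill = mutable.ArrayBuffer.empty[(String, String)]
    var sorted = false

    records.foreach { case (key, value) =>
      if (!sorted && (open.contains(key) || open.size < maxWriters)) {
        open.getOrElseUpdate(key, new StringBuilder).append(value)
      } else {
        // Too many writers would be open: fall back to sort-based writing.
        sorted = true
        spill += ((key, value))
      }
    }

    open.foreach { case (k, buf) => println(s"[concurrent] $k -> $buf") }
    // Sort-based phase: grouping by key lets each remaining partition be written
    // sequentially, with only one writer open at a time.
    spill.groupBy(_._1).toSeq.sortBy(_._1).foreach { case (k, vs) =>
      println(s"[sorted]     $k -> ${vs.map(_._2).mkString}")
    }
  }

  def main(args: Array[String]): Unit = {
    val rows = Iterator("a" -> "1", "b" -> "2", "a" -> "3", "c" -> "4", "b" -> "5")
    writeAll(rows, maxWriters = 2)
  }
}
```

The real implementation caches a `WriterStatus` (writer, records in file, file counter) per index so a writer can be resumed later, and the cap comes from `SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS`.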
+ concurrentWriters.remove(currentWriterId) + releaseCurrentWriter() + } + } + + if (isBucketed) { + currentWriterId.bucketId = nextBucketId + } + if (isPartitioned && currentWriterId.partitionValues != nextPartitionValues) { + currentWriterId.partitionValues = Some(nextPartitionValues.get.copy()) + if (!concurrentWriters.contains(currentWriterId)) { + statsTrackers.foreach(_.newPartition(currentWriterId.partitionValues.get)) + } + } + setupCurrentWriterUsingMap() } - val outputRow = getOutputRow(record) - currentWriter.write(outputRow) - statsTrackers.foreach(_.newRow(outputRow)) - recordsInFile += 1 + + if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + renewCurrentWriterIfTooManyRecords(currentWriterId.partitionValues, currentWriterId.bucketId) + // Update writer status in concurrent writers map, as a new writer is created. + updateCurrentWriterStatusInMap() + } + writeRecord(record) + } + + /** + * Write iterator of records with concurrent writers. + */ + override def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { + while (iterator.hasNext && !sorted) { + write(iterator.next()) + } + + if (iterator.hasNext) { + clearCurrentWriterStatus() + val sorter = concurrentOutputWriterSpec.createSorter() + val sortIterator = sorter.sort(iterator.asInstanceOf[Iterator[UnsafeRow]]) + while (sortIterator.hasNext) { + write(sortIterator.next()) + } + } + } + + /** + * Update current writer status in map. + */ + private def updateCurrentWriterStatusInMap(): Unit = { + val status = concurrentWriters(currentWriterId) + status.outputWriter = currentWriter + status.recordsInFile = recordsInFile + status.fileCounter = fileCounter + } + + /** + * Retrieve writer in map, or create a new writer if not exists. + */ + private def setupCurrentWriterUsingMap(): Unit = { + if (concurrentWriters.contains(currentWriterId)) { + val status = concurrentWriters(currentWriterId) + currentWriter = status.outputWriter + recordsInFile = status.recordsInFile + fileCounter = status.fileCounter + } else { + fileCounter = 0 + renewCurrentWriter( + currentWriterId.partitionValues, + currentWriterId.bucketId, + closeCurrentWriter = false) + if (!sorted) { + assert(concurrentWriters.size <= concurrentOutputWriterSpec.maxWriters, + s"Number of concurrent output file writers is ${concurrentWriters.size} " + + s" which is beyond max value ${concurrentOutputWriterSpec.maxWriters}") + } else { + assert(concurrentWriters.size <= concurrentOutputWriterSpec.maxWriters + 1, + s"Number of output file writers after sort is ${concurrentWriters.size} " + + s" which is beyond max value ${concurrentOutputWriterSpec.maxWriters + 1}") + } + concurrentWriters.put( + currentWriterId.copy(), + new WriterStatus(currentWriter, recordsInFile, fileCounter)) + if (concurrentWriters.size >= concurrentOutputWriterSpec.maxWriters && !sorted) { + // Fall back to sort-based sequential writer mode. + logInfo(s"Number of concurrent writers ${concurrentWriters.size} reaches the threshold. " + + "Fall back from concurrent writers to sort-based sequential writer. You may change " + + s"threshold with configuration ${SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS.key}") + sorted = true + } + } + } + + /** + * Clear the current writer status in map. 
+ */ + private def clearCurrentWriterStatus(): Unit = { + if (currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined) { + updateCurrentWriterStatusInMap() + } + currentWriterId.partitionValues = None + currentWriterId.bucketId = None + currentWriter = null + recordsInFile = 0 + fileCounter = 0 } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 6300e10c0b..6839a4db0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String @@ -73,6 +73,11 @@ object FileFormatWriter extends Logging { copy(child = newChild) } + /** Describes how concurrent output writers should be executed. */ + case class ConcurrentOutputWriterSpec( + maxWriters: Int, + createSorter: () => UnsafeExternalRowSorter) + /** * Basic work flow of this command is: * 1. Driver side setup, including output committer initialization and data source specific @@ -177,18 +182,27 @@ object FileFormatWriter extends Logging { committer.setupJob(job) try { - val rdd = if (orderingMatched) { - empty2NullPlan.execute() + val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) { + (empty2NullPlan.execute(), None) } else { // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and // the physical plan may have different attribute ids due to optimizer removing some // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. 
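The planner-side decision in `FileFormatWriter.write` in this hunk can be summarized as a three-way choice. The sketch below restates it with illustrative names (a hedged paraphrase, not Spark's API): either the required ordering already holds, or concurrent writers are enabled (a positive writer limit and no explicit sort columns) and the up-front sort is skipped, or the plan is sorted first and written sequentially.

```scala
object WriteStrategySketch {
  sealed trait WriteStrategy
  case object AlreadySorted extends WriteStrategy                           // execute the plan as-is
  case object SortFirst extends WriteStrategy                               // sort, then a single sequential writer
  final case class ConcurrentWriters(maxWriters: Int) extends WriteStrategy // sort only on writer fallback

  def choose(
      orderingMatched: Boolean,
      maxConcurrentWriters: Int,
      hasExplicitSortColumns: Boolean): WriteStrategy =
    if (orderingMatched) AlreadySorted
    else if (maxConcurrentWriters > 0 && !hasExplicitSortColumns) ConcurrentWriters(maxConcurrentWriters)
    else SortFirst

  def main(args: Array[String]): Unit = {
    assert(choose(orderingMatched = true, maxConcurrentWriters = 0, hasExplicitSortColumns = false) == AlreadySorted)
    assert(choose(orderingMatched = false, maxConcurrentWriters = 8, hasExplicitSortColumns = false) == ConcurrentWriters(8))
    assert(choose(orderingMatched = false, maxConcurrentWriters = 0, hasExplicitSortColumns = false) == SortFirst)
    println("write-strategy sketch OK")
  }
}
```

In the concurrent case the diff still builds the `SortExec` node, but only hands the writer a `() => UnsafeExternalRowSorter` factory so the sort happens lazily if the writer falls back.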
val orderingExpr = bindReferences( requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns) - SortExec( + val sortPlan = SortExec( orderingExpr, global = false, - child = empty2NullPlan).execute() + child = empty2NullPlan) + + val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters + val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty + if (concurrentWritersEnabled) { + (empty2NullPlan.execute(), + Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter()))) + } else { + (sortPlan.execute(), None) + } } // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single @@ -211,7 +225,8 @@ object FileFormatWriter extends Logging { sparkPartitionId = taskContext.partitionId(), sparkAttemptNumber = taskContext.taskAttemptId().toInt & Integer.MAX_VALUE, committer, - iterator = iter) + iterator = iter, + concurrentOutputWriterSpec = concurrentOutputWriterSpec) }, rddWithNonEmptyPartitions.partitions.indices, (index, res: WriteTaskResult) => { @@ -245,7 +260,8 @@ object FileFormatWriter extends Logging { sparkPartitionId: Int, sparkAttemptNumber: Int, committer: FileCommitProtocol, - iterator: Iterator[InternalRow]): WriteTaskResult = { + iterator: Iterator[InternalRow], + concurrentOutputWriterSpec: Option[ConcurrentOutputWriterSpec]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date(jobIdInstant), sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) @@ -273,15 +289,19 @@ object FileFormatWriter extends Logging { } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionDataWriter(description, taskAttemptContext, committer) + concurrentOutputWriterSpec match { + case Some(spec) => + new DynamicPartitionDataConcurrentWriter( + description, taskAttemptContext, committer, spec) + case _ => + new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) + } } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - while (iterator.hasNext) { - dataWriter.write(iterator.next()) - } + dataWriter.writeWithIterator(iterator) dataWriter.commit() })(catchBlock = { // If there is an error, abort the task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index a0b191e60f..4ed8943ef4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -40,6 +40,8 @@ import org.apache.spark.sql.types.{StructField, StructType} case class HadoopFsRelation( location: FileIndex, partitionSchema: StructType, + // The top-level columns in `dataSchema` should match the actual physical file schema, otherwise + // the ORC data source may not work with the by-ordinal mode. 
dataSchema: StructType, bucketSpec: Option[BucketSpec], fileFormat: FileFormat, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala index 1d7abe5b93..7c479d986f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala @@ -57,7 +57,7 @@ abstract class OutputWriterFactory extends Serializable { */ abstract class OutputWriter { /** - * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned + * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned * tables, dynamic partition columns are not included in rows to be written. */ def write(row: InternalRow): Unit @@ -67,4 +67,9 @@ abstract class OutputWriter { * the task output is committed. */ def close(): Unit + + /** + * The file path to write. Invoked on the executor side. + */ + def path(): String } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala index c39a82ee03..aaf866bced 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala @@ -32,20 +32,7 @@ trait WriteTaskStats extends Serializable * A trait for classes that are capable of collecting statistics on data that's being processed by * a single write task in [[FileFormatWriter]] - i.e. there should be one instance per executor. * - * This trait is coupled with the way [[FileFormatWriter]] works, in the sense that its methods - * will be called according to how tuples are being written out to disk, namely in sorted order - * according to partitionValue(s), then bucketId. - * - * As such, a typical call scenario is: - * - * newPartition -> newBucket -> newFile -> newRow -. - * ^ |______^___________^ ^ ^____| - * | | |______________| - * | |____________________________| - * |____________________________________________| - * - * newPartition and newBucket events are only triggered if the relation to be written out is - * partitioned and/or bucketed, respectively. + * newPartition event is only triggered if the relation to be written out is partitioned. */ trait WriteTaskStatsTracker { @@ -56,22 +43,20 @@ trait WriteTaskStatsTracker { */ def newPartition(partitionValues: InternalRow): Unit - /** - * Process the fact that a new bucket is about to written. - * Only triggered when the relation is bucketed by a (non-empty) sequence of columns. - * @param bucketId The bucket number. - */ - def newBucket(bucketId: Int): Unit - /** * Process the fact that a new file is about to be written. * @param filePath Path of the file into which future rows will be written. */ def newFile(filePath: String): Unit + /** + * Process the fact that a file is finished to be written and closed. + * @param filePath Path of the file. + */ + def closeFile(filePath: String): Unit + /** * Process the fact that a new row to update the tracked statistics accordingly. - * The row will be written to the most recently witnessed file (via `newFile`). * @note Keep in mind that any overhead here is per-row, obviously, * so implementations should be as lightweight as possible. 
* @param row Current data row to be processed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala index 2b549536ae..35d0e098b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} import org.apache.spark.sql.types.StructType class CsvOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala index 719d72f5b9..55602ce2ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} import org.apache.spark.sql.types.StructType class JsonOutputWriter( - path: String, + val path: String, options: JSONOptions, dataSchema: StructType, context: TaskAttemptContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala index 08086bcd91..6f215737f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types._ private[sql] class OrcOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala index 70f6726c58..efb322f3fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter // NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
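Exposing `path()` on every `OutputWriter` and moving the stats protocol to explicit `newFile`/`closeFile` events means a tracker no longer has to assume a single "current" file. A hedged sketch of that pattern follows (a hypothetical tracker; `sizeOf` stands in for querying the filesystem for a written file's length).

```scala
import scala.collection.mutable

// Hypothetical tracker following the newFile/closeFile protocol described above.
class FileStatsTrackerSketch(sizeOf: String => Option[Long]) {
  private val submittedFiles = mutable.HashSet[String]()
  private var numSubmittedFiles = 0
  private var numFiles = 0
  private var numBytes = 0L

  def newFile(path: String): Unit = {
    submittedFiles += path
    numSubmittedFiles += 1
  }

  def closeFile(path: String): Unit = {
    updateStats(path)
    submittedFiles.remove(path)
  }

  private def updateStats(path: String): Unit =
    sizeOf(path).foreach { len =>
      numBytes += len
      numFiles += 1
    }

  def finalStats(): (Int, Int, Long) = {
    // Account for any files that were never explicitly closed.
    submittedFiles.foreach(updateStats)
    submittedFiles.clear()
    (numSubmittedFiles, numFiles, numBytes)
  }
}

object FileStatsTrackerDemo {
  def main(args: Array[String]): Unit = {
    val sizes = Map("part-0" -> 100L, "part-1" -> 250L)
    val tracker = new FileStatsTrackerSketch(sizes.get)
    tracker.newFile("part-0")
    tracker.newFile("part-1")
    tracker.closeFile("part-0")
    // "part-1" is still open when the task finishes; finalStats picks it up.
    println(tracker.finalStats()) // (2, 2, 350)
  }
}
```

Sizing the still-registered files in `finalStats()` mirrors how the tracker reconciles the number of submitted files against the files it actually measured before logging a mismatch.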
-class ParquetOutputWriter(path: String, context: TaskAttemptContext) +class ParquetOutputWriter(val path: String, context: TaskAttemptContext) extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala index 2b1b81f60c..2fb37c0dc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} import org.apache.spark.sql.types.StructType class TextOutputWriter( - path: String, + val path: String, dataSchema: StructType, lineSeparator: Array[Byte], context: TaskAttemptContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala index 1f25fed300..d827e83623 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} -import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataWriter, SingleDirectoryDataWriter, WriteJobDescription} +import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataSingleWriter, SingleDirectoryDataWriter, WriteJobDescription} case class FileWriterFactory ( description: WriteJobDescription, @@ -35,7 +35,7 @@ case class FileWriterFactory ( if (description.partitionColumns.isEmpty) { new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionDataWriter(description, taskAttemptContext, committer) + new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 167ba45b88..1f57f17911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -81,6 +81,10 @@ object PushDownUtils extends PredicateHelper { relation: DataSourceV2Relation, projects: Seq[NamedExpression], filters: Seq[Expression]): (Scan, Seq[AttributeReference]) = { + val exprs = projects ++ filters + val requiredColumns = AttributeSet(exprs.flatMap(_.references)) + val neededOutput = relation.output.filter(requiredColumns.contains) + scanBuilder match { case r: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled => val rootFields = SchemaPruning.identifyRootFields(projects, filters) @@ -89,14 +93,12 @@ object PushDownUtils extends PredicateHelper { } else { new StructType() } - r.pruneColumns(prunedSchema) + val neededFieldNames = neededOutput.map(_.name).toSet + r.pruneColumns(StructType(prunedSchema.filter(f => 
neededFieldNames.contains(f.name)))) val scan = r.build() scan -> toOutputAttrs(scan.readSchema(), relation) case r: SupportsPushDownRequiredColumns => - val exprs = projects ++ filters - val requiredColumns = AttributeSet(exprs.flatMap(_.references)) - val neededOutput = relation.output.filter(requiredColumns.contains) r.pruneColumns(neededOutput.toStructType) val scan = r.build() // always project, in case the relation's output has been updated and doesn't match diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index d3bc4aed57..9a05e396d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.DYNAMIC_PRUNING_SUBQUERY import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan, SubqueryBroadcastExec} import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins._ @@ -49,7 +50,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) return plan } - plan transformAllExpressions { + plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) { case DynamicPruningSubquery( value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) => val sparkPlan = QueryExecution.createSparkPlan( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala index 3cb20f87ae..f2449a1ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/CustomMetrics.scala @@ -17,10 +17,7 @@ package org.apache.spark.sql.execution.metric -import java.text.NumberFormat -import java.util.Locale - -import org.apache.spark.sql.connector.CustomMetric +import org.apache.spark.sql.connector.metric.CustomMetric object CustomMetrics { private[spark] val V2_CUSTOM = "v2Custom" @@ -45,35 +42,3 @@ object CustomMetrics { } } } - -/** - * Built-in `CustomMetric` that sums up metric values. - */ -class CustomSumMetric extends CustomMetric { - override def name(): String = "CustomSumMetric" - - override def description(): String = "Sum up CustomMetric" - - override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { - taskMetrics.sum.toString - } -} - -/** - * Built-in `CustomMetric` that computes average of metric values. 
- */ -class CustomAvgMetric extends CustomMetric { - override def name(): String = "CustomAvgMetric" - - override def description(): String = "Average CustomMetric" - - override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { - val average = if (taskMetrics.isEmpty) { - 0.0 - } else { - taskMetrics.sum.toDouble / taskMetrics.length - } - val numberFormat = NumberFormat.getNumberInstance(Locale.US) - numberFormat.format(average) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index da39e8c455..959144bab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -24,7 +24,7 @@ import scala.concurrent.duration._ import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo -import org.apache.spark.sql.connector.CustomMetric +import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} @@ -113,7 +113,7 @@ object SQLMetrics { */ def createV2CustomMetric(sc: SparkContext, customMetric: CustomMetric): SQLMetric = { val acc = new SQLMetric(CustomMetrics.buildV2CustomMetricTypeName(customMetric)) - acc.register(sc, name = Some(customMetric.name()), countFailedValues = false) + acc.register(sc, name = Some(customMetric.description()), countFailedValues = false) acc } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 1c018be6d5..c6a70fb204 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -40,7 +40,6 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(e) || - e.isInstanceOf[GroupingExprRef] || agg.groupingExpressions.exists(_.semanticEquals(e)) } @@ -120,8 +119,23 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] { groupingExpr += expr } } + val aggExpr = agg.aggregateExpressions.map { expr => + expr.transformUp { + // PythonUDF over aggregate was pulled out by ExtractPythonUDFFromAggregate. + // PythonUDF here should be either + // 1. Argument of an aggregate function. + // CheckAnalysis guarantees the arguments are deterministic. + // 2. PythonUDF in grouping key. Grouping key must be deterministic. + // 3. PythonUDF not in grouping key. It either has no arguments or has the grouping key + // in its arguments. Such PythonUDF was pulled out by ExtractPythonUDFFromAggregate, too.
+ case p: PythonUDF if p.udfDeterministic => + val canonicalized = p.canonicalized.asInstanceOf[PythonUDF] + attributeMap.getOrElse(canonicalized, p) + }.asInstanceOf[NamedExpression] + } agg.copy( groupingExpressions = groupingExpr.toSeq, + aggregateExpressions = aggExpr, child = Project((projList ++ agg.child.output).toSeq, agg.child)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 15b85013c4..f96e9ee3ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} +import org.apache.spark.sql.catalyst.trees.TreePattern.{IN_SUBQUERY, SCALAR_SUBQUERY} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, DataType, StructType} @@ -176,7 +177,7 @@ case class InSubqueryExec( */ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { - plan.transformAllExpressions { + plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY)) { case subquery: expressions.ScalarSubquery => val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, subquery.plan) ScalarSubquery( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index a3238551b2..e7ab4a184b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -28,7 +28,7 @@ import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Status._ import org.apache.spark.scheduler._ -import org.apache.spark.sql.connector.CustomMetric +import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala index ed1d24b682..b7f3dec224 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.execution.UnaryExecNode -import org.apache.spark.sql.types.{CalendarIntervalType, DateType, IntegerType, TimestampType} +import org.apache.spark.sql.types.{CalendarIntervalType, DateType, DayTimeIntervalType, IntegerType, TimestampType, YearMonthIntervalType} trait WindowExecBase extends UnaryExecNode { def windowExpression: Seq[NamedExpression] @@ -95,8 +95,11 @@ trait WindowExecBase extends UnaryExecNode { // Create the projection 
which returns the current 'value' modified by adding the offset. val boundExpr = (expr.dataType, boundOffset.dataType) match { case (DateType, IntegerType) => DateAdd(expr, boundOffset) - case (TimestampType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(timeZone)) + case (DateType, YearMonthIntervalType) => DateAddYMInterval(expr, boundOffset) + case (TimestampType, CalendarIntervalType) => TimeAdd(expr, boundOffset, Some(timeZone)) + case (TimestampType, YearMonthIntervalType) => + TimestampAddYMInterval(expr, boundOffset, Some(timeZone)) + case (TimestampType, DayTimeIntervalType) => TimeAdd(expr, boundOffset, Some(timeZone)) case (a, b) if a == b => Add(expr, boundOffset) } val bound = MutableProjection.create(boundExpr :: Nil, child.output) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java index e418958bef..59c5263563 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameWriterV2Suite.java @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; -import org.apache.spark.sql.connector.InMemoryTableCatalog; +import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.StructType; import org.junit.After; diff --git a/sql/core/src/test/resources/sql-tests/inputs/extract.sql b/sql/core/src/test/resources/sql-tests/inputs/extract.sql index 0f1fd5bbcc..31603fba99 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/extract.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/extract.sql @@ -128,3 +128,34 @@ select c - i from t; select year(c - i) from t; select extract(year from c - i) from t; select extract(month from to_timestamp(c) - i) from t; + +-- extract fields from year-month/day-time intervals +select extract(YEAR from interval '2-1' YEAR TO MONTH); +select date_part('YEAR', interval '2-1' YEAR TO MONTH); +select extract(YEAR from -interval '2-1' YEAR TO MONTH); +select extract(MONTH from interval '2-1' YEAR TO MONTH); +select date_part('MONTH', interval '2-1' YEAR TO MONTH); +select extract(MONTH from -interval '2-1' YEAR TO MONTH); +select date_part(NULL, interval '2-1' YEAR TO MONTH); + +-- invalid +select extract(DAY from interval '2-1' YEAR TO MONTH); +select date_part('DAY', interval '2-1' YEAR TO MONTH); +select date_part('not_supported', interval '2-1' YEAR TO MONTH); + +select extract(DAY from interval '123 12:34:56.789123123' DAY TO SECOND); +select date_part('DAY', interval '123 12:34:56.789123123' DAY TO SECOND); +select extract(DAY from -interval '123 12:34:56.789123123' DAY TO SECOND); +select extract(HOUR from interval '123 12:34:56.789123123' DAY TO SECOND); +select date_part('HOUR', interval '123 12:34:56.789123123' DAY TO SECOND); +select extract(HOUR from -interval '123 12:34:56.789123123' DAY TO SECOND); +select extract(MINUTE from interval '123 12:34:56.789123123' DAY TO SECOND); +select date_part('MINUTE', interval '123 12:34:56.789123123' DAY TO SECOND); +select extract(MINUTE from -interval '123 12:34:56.789123123' DAY TO SECOND); +select extract(SECOND from interval '123 12:34:56.789123123' DAY TO SECOND); +select date_part('SECOND', 
interval '123 12:34:56.789123123' DAY TO SECOND); +select extract(SECOND from -interval '123 12:34:56.789123123' DAY TO SECOND); +select date_part(NULL, interval '123 12:34:56.789123123' DAY TO SECOND); + +select extract(MONTH from interval '123 12:34:56.789123123' DAY TO SECOND); +select date_part('not_supported', interval '123 12:34:56.789123123' DAY TO SECOND); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index 6dfe31e270..d6381e59e0 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -80,3 +80,14 @@ SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), (a), ()); SELECT a, b, count(1) FROM testData GROUP BY a, CUBE(a, b), GROUPING SETS((a, b), (a), ()); SELECT a, b, count(1) FROM testData GROUP BY a, CUBE(a, b), ROLLUP(a, b), GROUPING SETS((a, b), (a), ()); +-- Support nested CUBE/ROLLUP/GROUPING SETS in GROUPING SETS +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b)); +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ())); + +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), GROUPING SETS(ROLLUP(a, b))); +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b, a, b), (a, b, a), (a, b)); +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b, a, b), (a, b, a), (a, b))); + +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b), CUBE(a, b)); +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ()), GROUPING SETS((a, b), (a), (b), ())); +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), (a), (), (a, b), (a), (b), ()); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 988ad99418..6ee1014739 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -179,12 +179,3 @@ SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max( -- Aggregate with multiple distinct decimal columns SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col); - --- SPARK-34581: Don't optimize out grouping expressions from aggregate expressions without aggregate function -SELECT not(a IS NULL), count(*) AS c -FROM testData -GROUP BY a IS NULL; - -SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c -FROM testData -GROUP BY a IS NULL; diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql index 3fcbdacda6..063727a76e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql @@ -102,6 +102,10 @@ select interval 30 day day day; select interval (-30) days; select interval (a + 1) days; select interval 30 days days days; +SELECT INTERVAL '178956970-7' YEAR TO MONTH; +SELECT INTERVAL '178956970-8' YEAR TO MONTH; +SELECT INTERVAL '-178956970-8' YEAR TO MONTH; +SELECT INTERVAL -'178956970-8' YEAR TO MONTH; -- Interval year-month arithmetic @@ -218,3 +222,17 @@ select interval '1 day 1'; select interval '1 day 2' day; select interval 'interval 1' day; select interval '-\t 1' day; + +SELECT (INTERVAL '-178956970-8' YEAR TO 
MONTH) / 2; +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 5; +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1; +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1L; +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0; +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0D; + +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 2; +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 5; +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1; +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1L; +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0; +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0D; diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index 7419ca1bd0..d84659c4cc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -11,6 +11,18 @@ CREATE OR REPLACE TEMPORARY VIEW script_trans AS SELECT * FROM VALUES (7, 8, 9) AS script_trans(a, b, c); +CREATE OR REPLACE TEMPORARY VIEW complex_trans AS SELECT * FROM VALUES +(1, 1), +(1, 1), +(2, 2), +(2, 2), +(3, 3), +(2, 2), +(3, 3), +(1, 1), +(3, 3) +as complex_trans(a, b); + SELECT TRANSFORM(a) USING 'cat' AS (a) FROM t; @@ -342,3 +354,22 @@ SELECT TRANSFORM(b, MAX(a) AS max_a, CAST(sum(c) AS STRING)) FROM script_trans WHERE a <= 2 GROUP BY b; + +-- SPARK-33985: TRANSFORM with CLUSTER BY/ORDER BY/SORT BY +FROM ( + SELECT TRANSFORM(a, b) + USING 'cat' AS (a, b) + FROM complex_trans + CLUSTER BY a +) map_output +SELECT TRANSFORM(a, b) + USING 'cat' AS (a, b); + +FROM ( + SELECT TRANSFORM(a, b) + USING 'cat' AS (a, b) + FROM complex_trans + ORDER BY a +) map_output +SELECT TRANSFORM(a, b) + USING 'cat' AS (a, b); \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index 56f2b0b20c..46d3629a5d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -70,6 +70,18 @@ RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) FROM testData ORDER BY cate, val_date SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp RANGE BETWEEN CURRENT ROW AND interval 23 days 4 hours FOLLOWING) FROM testData ORDER BY cate, val_timestamp; +SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp +RANGE BETWEEN CURRENT ROW AND interval '1-1' year to month FOLLOWING) FROM testData +ORDER BY cate, val_timestamp; +SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp +RANGE BETWEEN CURRENT ROW AND interval '1 2:3:4.001' day to second FOLLOWING) FROM testData +ORDER BY cate, val_timestamp; +SELECT val_date, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_date +RANGE BETWEEN CURRENT ROW AND interval '1-1' year to month FOLLOWING) FROM testData +ORDER BY cate, val_date; +SELECT val_date, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_date +RANGE BETWEEN CURRENT ROW AND interval '1 2:3:4.001' day to second FOLLOWING) FROM testData +ORDER BY cate, val_date; -- RangeBetween with reverse OrderBy SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/group-analytics.sql.out 
b/sql/core/src/test/resources/sql-tests/results/ansi/group-analytics.sql.out index 1db8febb81..9dbfc4cf4f 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/group-analytics.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 44 +-- Number of queries: 52 -- !query @@ -1067,3 +1067,227 @@ struct 3 NULL 2 3 NULL 2 3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b)) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ())) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), GROUPING SETS(ROLLUP(a, b))) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 NULL 2 +1 NULL 2 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 NULL 2 +2 NULL 2 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 NULL 2 +3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b, a, b), (a, b, a), (a, b)) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 2 1 +2 1 1 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 2 1 +3 1 1 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 2 1 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b, a, b), (a, b, a), (a, b))) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 2 1 +2 1 1 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 2 1 +3 1 1 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 2 1 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b), CUBE(a, b)) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 2 1 +1 NULL 2 +1 NULL 2 +1 NULL 2 +1 NULL 2 +2 1 1 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 2 1 +2 NULL 2 +2 NULL 2 +2 NULL 2 +2 NULL 2 +3 1 1 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 2 1 +3 NULL 2 +3 NULL 2 +3 NULL 2 +3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ()), GROUPING SETS((a, b), (a), (b), ())) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 2 1 +1 NULL 2 +1 NULL 2 +1 NULL 2 +1 NULL 2 +2 1 1 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 2 1 +2 NULL 2 +2 NULL 2 +2 NULL 2 +2 NULL 2 +3 1 1 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 2 1 +3 NULL 2 +3 NULL 2 +3 NULL 2 +3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), (a), (), (a, b), (a), (b), ()) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 2 1 +1 NULL 2 +1 NULL 2 +1 NULL 2 +1 NULL 2 +2 1 1 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 2 1 +2 NULL 2 +2 NULL 2 +2 NULL 2 +2 NULL 2 +3 1 1 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 2 1 +3 NULL 2 +3 NULL 2 +3 NULL 2 +3 NULL 2 diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 781a7d739c..e383fc1b85 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 118 +-- Number of queries: 134 -- !query @@ -780,6 +780,44 @@ 
select interval 30 days days days -----------------------------^^^ +-- !query +SELECT INTERVAL '178956970-7' YEAR TO MONTH +-- !query schema +struct +-- !query output +178956970-7 + + +-- !query +SELECT INTERVAL '178956970-8' YEAR TO MONTH +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +Error parsing interval year-month string: integer overflow(line 1, pos 16) + +== SQL == +SELECT INTERVAL '178956970-8' YEAR TO MONTH +----------------^^^ + + +-- !query +SELECT INTERVAL '-178956970-8' YEAR TO MONTH +-- !query schema +struct +-- !query output +-178956970-8 + + +-- !query +SELECT INTERVAL -'178956970-8' YEAR TO MONTH +-- !query schema +struct +-- !query output +-178956970-8 + + -- !query create temporary view interval_arithmetic as select CAST(dateval AS date), CAST(tsval AS timestamp), dateval as strval from values @@ -1221,3 +1259,107 @@ select interval '-\t 1' day struct -- !query output -1 days + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 2 +-- !query schema +struct<(INTERVAL '-178956970-8' YEAR TO MONTH / 2):year-month interval> +-- !query output +-89478485-4 + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 5 +-- !query schema +struct<(INTERVAL '-178956970-8' YEAR TO MONTH / 5):year-month interval> +-- !query output +-35791394-2 + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1 +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow in integral divide. + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1L +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow in integral divide. + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0 +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0D +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +not in range + + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 2 +-- !query schema +struct<(INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND / 2):day-time interval> +-- !query output +-53375995 14:00:27.387904000 + + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 5 +-- !query schema +struct<(INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND / 5):day-time interval> +-- !query output +-21350398 05:36:10.955162000 + + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1 +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow in integral divide. + + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1L +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow in integral divide. 
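The day-time overflow results above follow from the physical encoding: a day-time interval value is carried as a Long count of microseconds, so INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND is exactly Long.MinValue microseconds. Dividing it by -1 (or -1L) has no representable result, hence "Overflow in integral divide", while dividing by 2 or 5 stays in range. A minimal standalone sketch of that arithmetic, using plain JDK math rather than Spark's actual implementation:

```scala
// Day-time intervals are backed by a Long microsecond count.
// -106751991 days 04:00:54.775808 is exactly Long.MinValue microseconds.
val micros = -106751991L * 24 * 60 * 60 * 1000000L - (4L * 3600 + 54) * 1000000L - 775808L
assert(micros == Long.MinValue)

// Exact negation (the core of dividing by -1) cannot be represented,
// matching the "Overflow in integral divide" results above.
try Math.negateExact(micros)
catch { case e: ArithmeticException => println(e.getMessage) } // "long overflow"
```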
+ + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0 +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow + + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0D +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +not in range diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out index 35cfda1767..63b5caac48 100644 --- a/sql/core/src/test/resources/sql-tests/results/extract.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 100 +-- Number of queries: 125 -- !query @@ -197,7 +197,7 @@ struct -- !query select extract(hour from c), extract(hour from i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -205,7 +205,7 @@ struct -- !query select extract(h from c), extract(h from i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -213,7 +213,7 @@ struct -- !query select extract(hours from c), extract(hours from i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -221,7 +221,7 @@ struct -- !query select extract(hr from c), extract(hr from i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -229,7 +229,7 @@ struct -- !query select extract(hrs from c), extract(hrs from i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -519,7 +519,7 @@ struct -- !query select date_part('hour', c), date_part('hour', i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -527,7 +527,7 @@ struct -- !query select date_part('h', c), date_part('h', i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -535,7 +535,7 @@ struct -- !query select date_part('hours', c), date_part('hours', i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -543,7 +543,7 @@ struct -- !query select date_part('hr', c), date_part('hr', i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -551,7 +551,7 @@ struct -- !query select date_part('hrs', c), date_part('hrs', i) from t -- !query schema -struct +struct -- !query output 7 16 @@ -805,3 +805,208 @@ select extract(month from to_timestamp(c) - i) from t struct -- !query output 8 + + +-- !query +select extract(YEAR from interval '2-1' YEAR TO MONTH) +-- !query schema +struct +-- !query output +2 + + +-- !query +select date_part('YEAR', interval '2-1' YEAR TO MONTH) +-- !query schema +struct +-- !query output +2 + + +-- !query +select extract(YEAR from -interval '2-1' YEAR TO MONTH) +-- !query schema +struct +-- !query output +-2 + + +-- !query +select extract(MONTH from interval '2-1' YEAR TO MONTH) +-- !query schema +struct +-- !query output +1 + + +-- !query +select date_part('MONTH', interval '2-1' YEAR TO MONTH) +-- !query schema +struct +-- !query output +1 + + +-- !query +select extract(MONTH from -interval '2-1' YEAR TO MONTH) +-- !query schema +struct +-- !query output +-1 + + +-- !query +select date_part(NULL, interval '2-1' YEAR TO MONTH) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select extract(DAY from interval '2-1' YEAR TO MONTH) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Literals of type 'DAY' are currently not supported for the year-month interval type.; line 1 pos 7 + + +-- !query +select date_part('DAY', interval '2-1' YEAR TO MONTH) +-- !query schema +struct<> +-- 
!query output +org.apache.spark.sql.AnalysisException +Literals of type 'DAY' are currently not supported for the year-month interval type.; line 1 pos 7 + + +-- !query +select date_part('not_supported', interval '2-1' YEAR TO MONTH) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Literals of type 'not_supported' are currently not supported for the year-month interval type.; line 1 pos 7 + + +-- !query +select extract(DAY from interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +123 + + +-- !query +select date_part('DAY', interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +123 + + +-- !query +select extract(DAY from -interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +-123 + + +-- !query +select extract(HOUR from interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +12 + + +-- !query +select date_part('HOUR', interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +12 + + +-- !query +select extract(HOUR from -interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +-12 + + +-- !query +select extract(MINUTE from interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +34 + + +-- !query +select date_part('MINUTE', interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +34 + + +-- !query +select extract(MINUTE from -interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +-34 + + +-- !query +select extract(SECOND from interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +56.789123 + + +-- !query +select date_part('SECOND', interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +56.789123 + + +-- !query +select extract(SECOND from -interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +-56.789123 + + +-- !query +select date_part(NULL, interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct +-- !query output +NULL + + +-- !query +select extract(MONTH from interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Literals of type 'MONTH' are currently not supported for the day-time interval type.; line 1 pos 7 + + +-- !query +select date_part('not_supported', interval '123 12:34:56.789123123' DAY TO SECOND) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Literals of type 'not_supported' are currently not supported for the day-time interval type.; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out index 6dc02ead9d..f249908163 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 44 +-- Number of queries: 52 -- !query @@ -1087,3 +1087,227 @@ struct 3 NULL 2 3 NULL 2 3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b)) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 
2 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ())) +-- !query schema +struct +-- !query output +1 1 1 +1 2 1 +1 NULL 2 +1 NULL 2 +2 1 1 +2 2 1 +2 NULL 2 +2 NULL 2 +3 1 1 +3 2 1 +3 NULL 2 +3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), GROUPING SETS(ROLLUP(a, b))) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 NULL 2 +1 NULL 2 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 NULL 2 +2 NULL 2 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 NULL 2 +3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b, a, b), (a, b, a), (a, b)) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 2 1 +2 1 1 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 2 1 +3 1 1 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 2 1 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b, a, b), (a, b, a), (a, b))) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 2 1 +2 1 1 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 2 1 +3 1 1 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 2 1 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(ROLLUP(a, b), CUBE(a, b)) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 2 1 +1 NULL 2 +1 NULL 2 +1 NULL 2 +1 NULL 2 +2 1 1 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 2 1 +2 NULL 2 +2 NULL 2 +2 NULL 2 +2 NULL 2 +3 1 1 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 2 1 +3 NULL 2 +3 NULL 2 +3 NULL 2 +3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS(GROUPING SETS((a, b), (a), ()), GROUPING SETS((a, b), (a), (b), ())) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 2 1 +1 NULL 2 +1 NULL 2 +1 NULL 2 +1 NULL 2 +2 1 1 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 2 1 +2 NULL 2 +2 NULL 2 +2 NULL 2 +2 NULL 2 +3 1 1 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 2 1 +3 NULL 2 +3 NULL 2 +3 NULL 2 +3 NULL 2 + + +-- !query +SELECT a, b, count(1) FROM testData GROUP BY a, GROUPING SETS((a, b), (a), (), (a, b), (a), (b), ()) +-- !query schema +struct +-- !query output +1 1 1 +1 1 1 +1 1 1 +1 2 1 +1 2 1 +1 2 1 +1 NULL 2 +1 NULL 2 +1 NULL 2 +1 NULL 2 +2 1 1 +2 1 1 +2 1 1 +2 2 1 +2 2 1 +2 2 1 +2 NULL 2 +2 NULL 2 +2 NULL 2 +2 NULL 2 +3 1 1 +3 1 1 +3 1 1 +3 2 1 +3 2 1 +3 2 1 +3 NULL 2 +3 NULL 2 +3 NULL 2 +3 NULL 2 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index b5471a785a..1d8c44c291 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 64 +-- Number of queries: 62 -- !query @@ -642,25 +642,3 @@ SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 struct -- !query output 1.0000 1 - - --- !query -SELECT not(a IS NULL), count(*) AS c -FROM testData -GROUP BY a IS NULL --- !query schema -struct<(NOT (a IS NULL)):boolean,c:bigint> --- !query output -false 2 -true 7 - - --- !query -SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c -FROM testData -GROUP BY a IS NULL --- !query schema -struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint> --- !query output -0.7604953758285915 7 -1.0 2 diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out 
index d525d8044a..a2cbea2906 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 118 +-- Number of queries: 134 -- !query @@ -774,6 +774,44 @@ select interval 30 days days days -----------------------------^^^ +-- !query +SELECT INTERVAL '178956970-7' YEAR TO MONTH +-- !query schema +struct +-- !query output +178956970-7 + + +-- !query +SELECT INTERVAL '178956970-8' YEAR TO MONTH +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +Error parsing interval year-month string: integer overflow(line 1, pos 16) + +== SQL == +SELECT INTERVAL '178956970-8' YEAR TO MONTH +----------------^^^ + + +-- !query +SELECT INTERVAL '-178956970-8' YEAR TO MONTH +-- !query schema +struct +-- !query output +-178956970-8 + + +-- !query +SELECT INTERVAL -'178956970-8' YEAR TO MONTH +-- !query schema +struct +-- !query output +-178956970-8 + + -- !query create temporary view interval_arithmetic as select CAST(dateval AS date), CAST(tsval AS timestamp), dateval as strval from values @@ -1210,3 +1248,107 @@ select interval '-\t 1' day struct -- !query output -1 days + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 2 +-- !query schema +struct<(INTERVAL '-178956970-8' YEAR TO MONTH / 2):year-month interval> +-- !query output +-89478485-4 + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / 5 +-- !query schema +struct<(INTERVAL '-178956970-8' YEAR TO MONTH / 5):year-month interval> +-- !query output +-35791394-2 + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1 +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow in integral divide. + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1L +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow in integral divide. + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0 +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow + + +-- !query +SELECT (INTERVAL '-178956970-8' YEAR TO MONTH) / -1.0D +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +not in range + + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 2 +-- !query schema +struct<(INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND / 2):day-time interval> +-- !query output +-53375995 14:00:27.387904000 + + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / 5 +-- !query schema +struct<(INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND / 5):day-time interval> +-- !query output +-21350398 05:36:10.955162000 + + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1 +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow in integral divide. + + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1L +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow in integral divide. 
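The year-month cases in this golden file follow the same pattern: a year-month interval is backed by an Int number of months, so INTERVAL '178956970-7' YEAR TO MONTH is exactly Int.MaxValue months (which is why it parses), '178956970-8' needs one month more (hence the parse-time "integer overflow"), and '-178956970-8' is Int.MinValue months, which cannot be divided by -1. A small sketch of the boundary arithmetic, again plain Scala rather than Spark code:

```scala
// Year-month intervals are backed by an Int month count.
val maxMonths = 178956970 * 12 + 7   // 2147483647 == Int.MaxValue, so '178956970-7' parses
val minMonths = -178956970 * 12 - 8  // -2147483648 == Int.MinValue, so '-178956970-8' parses
assert(maxMonths == Int.MaxValue && minMonths == Int.MinValue)

// '178956970-8' would need Int.MaxValue + 1 months, and negating or dividing
// Int.MinValue months by -1 overflows as well:
try Math.negateExact(minMonths)
catch { case e: ArithmeticException => println(e.getMessage) } // "integer overflow"
```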
+ + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0 +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +Overflow + + +-- !query +SELECT (INTERVAL '-106751991 04:00:54.775808' DAY TO SECOND) / -1.0D +-- !query schema +struct<> +-- !query output +java.lang.ArithmeticException +not in range diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index 1d7e9cdb43..6f94e742b8 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 44 +-- Number of queries: 47 -- !query @@ -26,6 +26,24 @@ struct<> +-- !query +CREATE OR REPLACE TEMPORARY VIEW complex_trans AS SELECT * FROM VALUES +(1, 1), +(1, 1), +(2, 2), +(2, 2), +(3, 3), +(2, 2), +(3, 3), +(1, 1), +(3, 3) +as complex_trans(a, b) +-- !query schema +struct<> +-- !query output + + + -- !query SELECT TRANSFORM(a) USING 'cat' AS (a) @@ -717,3 +735,49 @@ SELECT TRANSFORM(b, MAX(a) AS max_a, CAST(sum(c) AS STRING)) FROM script_trans WHERE a <= 2 GROUP BY b + + +-- !query +FROM ( + SELECT TRANSFORM(a, b) + USING 'cat' AS (a, b) + FROM complex_trans + CLUSTER BY a +) map_output +SELECT TRANSFORM(a, b) + USING 'cat' AS (a, b) +-- !query schema +struct +-- !query output +1 1 +1 1 +1 1 +2 2 +2 2 +2 2 +3 3 +3 3 +3 3 + + +-- !query +FROM ( + SELECT TRANSFORM(a, b) + USING 'cat' AS (a, b) + FROM complex_trans + ORDER BY a +) map_output +SELECT TRANSFORM(a, b) + USING 'cat' AS (a, b) +-- !query schema +struct +-- !query output +1 1 +1 1 +1 1 +2 2 +2 2 +2 2 +3 3 +3 3 +3 3 diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index c377658722..7443b95582 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 46 +-- Number of queries: 50 -- !query @@ -211,6 +211,71 @@ NULL NULL NULL 2020-12-30 16:00:00 b 1.6093728E9 +-- !query +SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp +RANGE BETWEEN CURRENT ROW AND interval '1-1' year to month FOLLOWING) FROM testData +ORDER BY cate, val_timestamp +-- !query schema +struct +-- !query output +NULL NULL NULL +2017-07-31 17:00:00 NULL 1.5015456E9 +2017-07-31 17:00:00 a 1.5016970666666667E9 +2017-07-31 17:00:00 a 1.5016970666666667E9 +2017-08-05 23:13:20 a 1.502E9 +2020-12-30 16:00:00 a 1.6093728E9 +2017-07-31 17:00:00 b 1.5022728E9 +2017-08-17 13:00:00 b 1.503E9 +2020-12-30 16:00:00 b 1.6093728E9 + + +-- !query +SELECT val_timestamp, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_timestamp +RANGE BETWEEN CURRENT ROW AND interval '1 2:3:4.001' day to second FOLLOWING) FROM testData +ORDER BY cate, val_timestamp +-- !query schema +struct +-- !query output +NULL NULL NULL +2017-07-31 17:00:00 NULL 1.5015456E9 +2017-07-31 17:00:00 a 1.5015456E9 +2017-07-31 17:00:00 a 1.5015456E9 +2017-08-05 23:13:20 a 1.502E9 +2020-12-30 16:00:00 a 1.6093728E9 +2017-07-31 17:00:00 b 1.5015456E9 +2017-08-17 13:00:00 b 1.503E9 +2020-12-30 16:00:00 b 1.6093728E9 + + +-- !query +SELECT val_date, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_date +RANGE BETWEEN CURRENT ROW AND interval '1-1' year to month 
FOLLOWING) FROM testData +ORDER BY cate, val_date +-- !query schema +struct +-- !query output +NULL NULL NULL +2017-08-01 NULL 1.5015456E9 +2017-08-01 a 1.5016970666666667E9 +2017-08-01 a 1.5016970666666667E9 +2017-08-02 a 1.502E9 +2020-12-31 a 1.6093728E9 +2017-08-01 b 1.5022728E9 +2017-08-03 b 1.503E9 +2020-12-31 b 1.6093728E9 + + +-- !query +SELECT val_date, cate, avg(val_timestamp) OVER(PARTITION BY cate ORDER BY val_date +RANGE BETWEEN CURRENT ROW AND interval '1 2:3:4.001' day to second FOLLOWING) FROM testData +ORDER BY cate, val_date +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.cate ORDER BY testdata.val_date ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND INTERVAL '1 02:03:04.001' DAY TO SECOND FOLLOWING)' due to data type mismatch: The data type 'date' used in the order specification does not match the data type 'day-time interval' which is used in the range frame.; line 1 pos 46 + + -- !query SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 7d1e4ff040..c06544ee00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.{InMemoryPartitionTableCatalog, SchemaRequiredDataSource} +import org.apache.spark.sql.connector.SchemaRequiredDataSource +import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 5108c0169b..ad5d73c774 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -1686,6 +1686,61 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructType(Seq(StructField("a", IntegerType, nullable = true)))) } + test("SPARK-35213: chained withField operations should have correct schema for new columns") { + val df = spark.createDataFrame( + sparkContext.parallelize(Row(null) :: Nil), + StructType(Seq(StructField("data", NullType)))) + + checkAnswer( + df.withColumn("data", struct() + .withField("a", struct()) + .withField("b", struct()) + .withField("a.aa", lit("aa1")) + .withField("b.ba", lit("ba1")) + .withField("a.ab", lit("ab1"))), + Row(Row(Row("aa1", "ab1"), Row("ba1"))) :: Nil, + StructType(Seq( + StructField("data", StructType(Seq( + StructField("a", StructType(Seq( + StructField("aa", StringType, nullable = false), + StructField("ab", StringType, nullable = false) + )), nullable = false), + StructField("b", StructType(Seq( + StructField("ba", StringType, nullable = false) + )), nullable = false) + )), nullable = false) + )) + ) + } + + test("SPARK-35213: optimized withField operations 
should maintain correct nested struct " + + "ordering") { + val df = spark.createDataFrame( + sparkContext.parallelize(Row(null) :: Nil), + StructType(Seq(StructField("data", NullType)))) + + checkAnswer( + df.withColumn("data", struct() + .withField("a", struct().withField("aa", lit("aa1"))) + .withField("b", struct().withField("ba", lit("ba1"))) + ) + .withColumn("data", col("data").withField("b.bb", lit("bb1"))) + .withColumn("data", col("data").withField("a.ab", lit("ab1"))), + Row(Row(Row("aa1", "ab1"), Row("ba1", "bb1"))) :: Nil, + StructType(Seq( + StructField("data", StructType(Seq( + StructField("a", StructType(Seq( + StructField("aa", StringType, nullable = false), + StructField("ab", StringType, nullable = false) + )), nullable = false), + StructField("b", StructType(Seq( + StructField("ba", StringType, nullable = false), + StructField("bb", StringType, nullable = false) + )), nullable = false) + )), nullable = false) + )) + ) + } test("dropFields should throw an exception if called on a non-StructType column") { intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index c53bcf045d..c6f6cbdbf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1135,7 +1135,7 @@ class DataFrameAggregateSuite extends QueryTest val sumDF2 = df.groupBy($"class").agg(sum($"year-month"), sum($"day-time")) checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) :: Row(2, Period.ofMonths(1), Duration.ofDays(1)) :: - Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) ::Nil) + Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) :: Nil) assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false), StructField("sum(year-month)", YearMonthIntervalType), @@ -1173,7 +1173,7 @@ class DataFrameAggregateSuite extends QueryTest val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time")) checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) :: Row(2, Period.ofMonths(1), Duration.ofDays(1)) :: - Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) ::Nil) + Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil) assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false), StructField("avg(year-month)", YearMonthIntervalType), @@ -1188,6 +1188,13 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(df2.select(avg($"day-time")), Nil) } assert(error2.toString contains "java.lang.ArithmeticException: long overflow") + + val df3 = df.filter($"class" > 4) + val avgDF3 = df3.select(avg($"year-month"), avg($"day-time")) + checkAnswer(avgDF3, Row(null, null) :: Nil) + + val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day-time")) + checkAnswer(avgDF4, Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index 35e732e084..8aef27a1b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -25,8 +25,7 @@ import 
org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} -import org.apache.spark.sql.connector.{InMemoryTable, InMemoryTableCatalog} -import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, InMemoryTableCatalog, TableCatalog} import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index 688288abee..13d1285401 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -506,6 +506,14 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite checkKeywordsExistsInExplain(df2, keywords = "[key1=value1, KEY2=VALUE2]") } } + + test("SPARK-35225: Handle empty output for analyzed plan") { + withTempView("test") { + checkKeywordsExistsInExplain( + sql("CREATE TEMPORARY VIEW test AS SELECT 1"), + "== Analyzed Logical Plan ==\nCreateViewCommand") + } + } } class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala index 807c6b2a67..2f56fbaf7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.expressions.Hex -import org.apache.spark.sql.connector.InMemoryPartitionTableCatalog +import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index efb87dafe0..d83d1a2755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -22,7 +22,7 @@ import java.util.Collections import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan} -import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala index c973e2ba30..44fbc639a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.{DataFrame, Row, SaveMode} -import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, Table, TableCatalog} class DataSourceV2SQLSessionCatalogSuite extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala index 8922eea8e0..3ef242f90f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector import org.scalatest.BeforeAndAfter import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.connector.catalog.CatalogPlugin +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryPartitionTableCatalog, InMemoryTableCatalog, StagingInMemoryTableCatalog} import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 3aad644655..076dad7530 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression} -import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, TableCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, SupportsCatalogOptions, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform} import org.apache.spark.sql.execution.QueryExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala index 9beef690cb..847953e09c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext} -import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} import 
org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns, V1Scan} import org.apache.spark.sql.execution.RowDataSourceScanExec diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 45ddc6a6fc..7effc747ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveM import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, SupportsRead, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index 6fd9dc4e39..db4a9c153c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, QueryTest} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning} -import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog} import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder} import org.apache.spark.sql.connector.expressions.LogicalExpressions._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 40f25d5599..c845dd81f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.sql.{Date, Timestamp} +import java.time.{Duration, Period} import org.json4s.DefaultFormats import org.json4s.JsonDSL._ @@ -43,6 +44,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU import testImplicits._ import ScriptTransformationIOSchema._ + protected def defaultSerDe(): String + protected val uncaughtExceptionHandler = new TestUncaughtExceptionHandler private var 
defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _ @@ -599,6 +602,37 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU 'e.cast("string")).collect()) } } + + test("SPARK-35220: DayTimeIntervalType/YearMonthIntervalType show different " + + "between hive serde and row format delimited\t") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq( + (Duration.ofDays(1), Period.ofMonths(10)) + ).toDF("a", "b") + df.createTempView("v") + + if (defaultSerDe == "hive-serde") { + checkAnswer(sql( + """ + |SELECT TRANSFORM(a, b) + | USING 'cat' AS (a, b) + |FROM v + |""".stripMargin), + identity, + Row("1 00:00:00.000000000", "0-10") :: Nil) + } else { + checkAnswer(sql( + """ + |SELECT TRANSFORM(a, b) + | USING 'cat' AS (a, b) + |FROM v + |""".stripMargin), + identity, + Row("INTERVAL '1 00:00:00' DAY TO SECOND", "INTERVAL '0-10' YEAR TO MONTH") :: Nil) + } + } + } } case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala index f16265ee61..f8366b3f7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/HiveResultSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.time.{Duration, Period} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils -import org.apache.spark.sql.connector.InMemoryTableCatalog +import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog import org.apache.spark.sql.execution.HiveResult._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index f69aea3729..5638743b76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.test.SharedSparkSession class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with SharedSparkSession { import testImplicits._ + override protected def defaultSerDe(): String = "row-format-delimited" + override def createScriptTransformationExec( script: String, output: Seq[Attribute], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 31b6921132..2598d3ba8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -1575,4 +1575,19 @@ class AdaptiveQueryExecSuite checkNoCoalescePartitions(df.sort($"key"), ENSURE_REQUIREMENTS) } } + + test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") { + withTable("t") { + withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + spark.sql("CREATE TABLE t (c1 int) USING PARQUET") + val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) 
FROM t GROUP BY c1") + assert( + collect(adaptive) { + case c @ CustomShuffleReaderExec(_, partitionSpecs) if partitionSpecs.length == 1 => c + }.length == 1 + ) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index f58d3246f5..1684633c92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.arrow +import org.apache.arrow.vector.IntervalDayVector + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util._ @@ -54,6 +56,8 @@ class ArrowWriterSuite extends SparkFunSuite { case BinaryType => reader.getBinary(rowId) case DateType => reader.getInt(rowId) case TimestampType => reader.getLong(rowId) + case YearMonthIntervalType => reader.getInt(rowId) + case DayTimeIntervalType => reader.getLong(rowId) } assert(value === datum) } @@ -73,6 +77,33 @@ class ArrowWriterSuite extends SparkFunSuite { check(DateType, Seq(0, 1, 2, null, 4)) check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles") check(NullType, Seq(null, null, null)) + check(YearMonthIntervalType, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue)) + check(DayTimeIntervalType, Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L), + (Long.MinValue + 808L))) + } + + test("long overflow for DayTimeIntervalType") + { + val schema = new StructType().add("value", DayTimeIntervalType, nullable = true) + val writer = ArrowWriter.create(schema, null) + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + val valueVector = writer.root.getFieldVectors().get(0).asInstanceOf[IntervalDayVector] + + valueVector.set(0, 106751992, 0) + valueVector.set(1, 106751991, Int.MaxValue) + + // first long overflow for test Math.multiplyExact() + val msg = intercept[java.lang.ArithmeticException] { + reader.getLong(0) + }.getMessage + assert(msg.equals("long overflow")) + + // second long overflow for test Math.addExact() + val msg1 = intercept[java.lang.ArithmeticException] { + reader.getLong(1) + }.getMessage + assert(msg1.equals("long overflow")) + writer.root.close() } test("get multiple") { @@ -97,6 +128,8 @@ class ArrowWriterSuite extends SparkFunSuite { case DoubleType => reader.getDoubles(0, data.size) case DateType => reader.getInts(0, data.size) case TimestampType => reader.getLongs(0, data.size) + case YearMonthIntervalType => reader.getInts(0, data.size) + case DayTimeIntervalType => reader.getLongs(0, data.size) } assert(values === data) @@ -111,6 +144,8 @@ class ArrowWriterSuite extends SparkFunSuite { check(DoubleType, (0 until 10).map(_.toDouble)) check(DateType, (0 until 10)) check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles") + check(YearMonthIntervalType, (0 until 10)) + check(DayTimeIntervalType, (-10 until 10).map(_ * 1000.toLong)) } test("array") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index d77ef6e6bd..b8d7b774d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -80,7 +80,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hashmap = F", numIters = 3) { _ => + benchmark.addCase("codegen = T, hashmap = F", numIters = 3) { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", @@ -89,7 +89,16 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hashmap = T", numIters = 5) { _ => + benchmark.addCase("codegen = T, row-based hashmap = T", numIters = 5) { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") { + f() + } + } + + benchmark.addCase("codegen = T, vectorized hashmap = T", numIters = 5) { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", @@ -116,7 +125,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hashmap = F", numIters = 3) { _ => + benchmark.addCase("codegen = T, hashmap = F", numIters = 3) { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", @@ -125,7 +134,16 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hashmap = T", numIters = 5) { _ => + benchmark.addCase("codegen = T, row-based hashmap = T", numIters = 5) { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") { + f() + } + } + + benchmark.addCase("codegen = T, vectorized hashmap = T", numIters = 5) { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", @@ -151,7 +169,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hashmap = F", numIters = 3) { _ => + benchmark.addCase("codegen = T, hashmap = F", numIters = 3) { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", @@ -160,7 +178,16 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hashmap = T", numIters = 5) { _ => + benchmark.addCase("codegen = T, row-based hashmap = T", numIters = 5) { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") { + f() + } + } + + benchmark.addCase("codegen = T, vectorized hashmap = T", numIters = 5) { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", @@ -186,7 +213,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hashmap = F") { _ => + benchmark.addCase("codegen = T, hashmap = F") { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", @@ -195,7 +222,16 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hashmap = T") { _ => + benchmark.addCase("codegen = T, row-based hashmap = T") { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") { + f() + } + 
} + + benchmark.addCase("codegen = T, vectorized hashmap = T") { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", @@ -231,7 +267,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hashmap = F") { _ => + benchmark.addCase("codegen = T, hashmap = F") { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "false", @@ -240,7 +276,16 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hashmap = T") { _ => + benchmark.addCase("codegen = T, row-based hashmap = T") { _ => + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", + SQLConf.ENABLE_VECTORIZED_HASH_MAP.key -> "false") { + f() + } + } + + benchmark.addCase("codegen = T, vectorized hashmap = T") { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> "true", @@ -291,7 +336,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hugeMethodLimit = 10000") { _ => + benchmark.addCase("codegen = T, hugeMethodLimit = 10000") { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "10000") { @@ -299,7 +344,7 @@ object AggregateBenchmark extends SqlBasedBenchmark { } } - benchmark.addCase("codegen = T hugeMethodLimit = 1500") { _ => + benchmark.addCase("codegen = T, hugeMethodLimit = 1500") { _ => withSQLConf( SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "1500") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala index 1f47744ce4..ba683c049a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.SparkConf import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.InMemoryPartitionTableCatalog +import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala index 2dd80b7bb6..bed04f4f26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.analysis.ResolvePartitionSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.connector.{InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog} -import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier} +import 
org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier, InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog} import org.apache.spark.sql.test.SharedSparkSession /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala index 7a2c136eea..bafb6608c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.BasicInMemoryTableCatalog +import org.apache.spark.sql.connector.catalog.BasicInMemoryTableCatalog import org.apache.spark.sql.execution.command import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 765d2fc584..ac5c28953a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -351,6 +351,43 @@ abstract class SchemaPruningSuite } } + testSchemaPruning("SPARK-34638: nested column prune on generator output") { + val query1 = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.first") + checkScan(query1, "struct<friends:array<struct<first:string>>>") + checkAnswer(query1, Row("Susan") :: Nil) + + // Currently we don't prune multiple field case.
+ val query2 = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.first", "friend.middle") + checkScan(query2, "struct<friends:array<struct<first:string,middle:string,last:string>>>") + checkAnswer(query2, Row("Susan", "Z.") :: Nil) + + val query3 = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.first", "friend.middle", "friend") + checkScan(query3, "struct<friends:array<struct<first:string,middle:string,last:string>>>") + checkAnswer(query3, Row("Susan", "Z.", Row("Susan", "Z.", "Smith")) :: Nil) + } + + testSchemaPruning("SPARK-34638: nested column prune on generator output - case-sensitivity") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val query1 = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.First") + checkScan(query1, "struct<friends:array<struct<first:string>>>") + checkAnswer(query1, Row("Susan") :: Nil) + + val query2 = spark.table("contacts") + .select(explode(col("friends")).as("friend")) + .select("friend.MIDDLE") + checkScan(query2, "struct<friends:array<struct<middle:string>>>") + checkAnswer(query2, Row("Z.") :: Nil) + } + } + testSchemaPruning("select one deep nested complex field after repartition") { val query = sql("select * from contacts") .repartition(100) @@ -816,4 +853,21 @@ abstract class SchemaPruningSuite Row("John", "Y.") :: Nil) } } + + test("SPARK-34638: queries should not fail on unsupported cases") { + withTable("nested_array") { + sql("select * from values array(array(named_struct('a', 1, 'b', 3), " + + "named_struct('a', 2, 'b', 4))) T(items)").write.saveAsTable("nested_array") + val query = sql("select d.a from (select explode(c) d from " + + "(select explode(items) c from nested_array))") + checkAnswer(query, Row(1) :: Row(2) :: Nil) + } + + withTable("map") { + sql("select * from values map(1, named_struct('a', 1, 'b', 3), " + + "2, named_struct('a', 2, 'b', 4)) T(items)").write.saveAsTable("map") + val query = sql("select d.a from (select explode(items) (c, d) from map)") + checkAnswer(query, Row(1) :: Row(2) :: Nil) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 7bd068c0f9..eee8e2ecc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -633,4 +633,20 @@ class OrcSourceSuite extends OrcSuite with SharedSparkSession { } } } + + test("SPARK-34897: Support reconcile schemas based on index after nested column pruning") { + withTable("t1") { + spark.sql( + """ + |CREATE TABLE t1 ( + | _col0 INT, + | _col1 STRING, + | _col2 STRUCT<c1: STRING, c2: STRING, c3: STRING, c4: BIGINT>) + |USING ORC + |""".stripMargin) + + spark.sql("INSERT INTO t1 values(1, '2', struct('a', 'b', 'c', 10L))") + checkAnswer(spark.sql("SELECT _col0, _col2.c1 FROM t1"), Seq(Row(1, "a"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala index e2fa03ff23..440b0dc08e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/CustomMetricsSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.execution.metric import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.metric.{CustomAvgMetric, CustomSumMetric} class CustomMetricsSuite extends SparkFunSuite { test("Build/parse custom metric metric 
type") { - Seq(new CustomSumMetric, new CustomAvgMetric).foreach { customMetric => + Seq(new TestCustomSumMetric, new TestCustomAvgMetric).foreach { customMetric => val metricType = CustomMetrics.buildV2CustomMetricTypeName(customMetric) assert(metricType == CustomMetrics.V2_CUSTOM + "_" + customMetric.getClass.getCanonicalName) @@ -33,7 +34,7 @@ class CustomMetricsSuite extends SparkFunSuite { } test("Built-in CustomSumMetric") { - val metric = new CustomSumMetric + val metric = new TestCustomSumMetric val metricValues1 = Array(0L, 1L, 5L, 5L, 7L, 10L) assert(metric.aggregateTaskMetrics(metricValues1) == metricValues1.sum.toString) @@ -43,7 +44,7 @@ class CustomMetricsSuite extends SparkFunSuite { } test("Built-in CustomAvgMetric") { - val metric = new CustomAvgMetric + val metric = new TestCustomAvgMetric val metricValues1 = Array(0L, 1L, 5L, 5L, 7L, 10L) assert(metric.aggregateTaskMetrics(metricValues1) == "4.667") @@ -52,3 +53,13 @@ class CustomMetricsSuite extends SparkFunSuite { assert(metric.aggregateTaskMetrics(metricValues2) == "0") } } + +private[spark] class TestCustomSumMetric extends CustomSumMetric { + override def name(): String = "CustomSumMetric" + override def description(): String = "Sum up CustomMetric" +} + +private[spark] class TestCustomAvgMetric extends CustomAvgMetric { + override def name(): String = "CustomAvgMetric" + override def description(): String = "Average CustomMetric" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index a58265124d..612b74a661 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -37,7 +37,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.connector.{CustomMetric, CustomTaskMetric, RangeInputPartition, SimpleScanBuilder} +import org.apache.spark.sql.connector.{RangeInputPartition, SimpleScanBuilder} +import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index b85a668e5b..90127557f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -117,6 +117,21 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { } } + test(s"SPARK-35168: ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} should respect" + + s" ${SQLConf.SHUFFLE_PARTITIONS.key}") { + spark.sessionState.conf.clear() + try { + sql(s"SET ${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key}=true") + sql(s"SET ${SQLConf.COALESCE_PARTITIONS_ENABLED.key}=true") + sql(s"SET ${SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key}=1") + sql(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=2") + checkAnswer(sql(s"SET 
${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}"), + Row(SQLConf.SHUFFLE_PARTITIONS.key, "2")) + } finally { + spark.sessionState.conf.clear() + } + } + test("SPARK-31234: reset will not change static sql configs and spark core configs") { val conf = spark.sparkContext.getConf.getAll.toMap val appName = conf.get("spark.app.name") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 02f91399fc..0e2fcfbd46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -54,27 +54,31 @@ class ContinuousSuiteBase extends StreamTest { protected def waitForRateSourceCommittedValue( query: ContinuousExecution, - desiredValue: Long, + partitionIdToDesiredValue: Map[Int, Long], maxWaitTimeMs: Long): Unit = { - def readHighestCommittedValue(c: ContinuousExecution): Option[Long] = { + def readCommittedValues(c: ContinuousExecution): Option[Map[Int, Long]] = { c.committedOffsets.lastOption.map { case (_, offset) => offset match { case o: RateStreamOffset => - o.partitionToValueAndRunTimeMs.map { - case (_, ValueRunTimeMsPair(value, _)) => value - }.max + o.partitionToValueAndRunTimeMs.mapValues(_.value).toMap } } } + def reachDesiredValues: Boolean = { + val committedValues = readCommittedValues(query).getOrElse(Map.empty) + partitionIdToDesiredValue.forall { case (key, value) => + committedValues.contains(key) && committedValues(key) > value + } + } + val maxWait = System.currentTimeMillis() + maxWaitTimeMs - while (System.currentTimeMillis() < maxWait && - readHighestCommittedValue(query).getOrElse(Long.MinValue) < desiredValue) { + while (System.currentTimeMillis() < maxWait && !reachDesiredValues) { Thread.sleep(100) } if (System.currentTimeMillis() > maxWait) { logWarning(s"Couldn't reach desired value in $maxWaitTimeMs milliseconds!" 
+ - s"Current highest committed value is ${readHighestCommittedValue(query)}") + s"Current committed values is ${readCommittedValues(query)}") } } @@ -264,7 +268,7 @@ class ContinuousSuite extends ContinuousSuiteBase { val expected = Set(0, 1, 2, 3) val continuousExecution = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.asInstanceOf[ContinuousExecution] - waitForRateSourceCommittedValue(continuousExecution, expected.max, 20 * 1000) + waitForRateSourceCommittedValue(continuousExecution, Map(0 -> 2, 1 -> 3), 20 * 1000) query.stop() val results = spark.read.table("noharness").collect() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala index 4c5c5e63ce..49e5218ea3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala @@ -29,8 +29,8 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 -import org.apache.spark.sql.connector.{FakeV2Provider, InMemoryTableCatalog, InMemoryTableSessionCatalog} -import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, Table, TableCapability, V2TableWithV1Fallback} +import org.apache.spark.sql.connector.{FakeV2Provider, InMemoryTableSessionCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, SupportsRead, Table, TableCapability, V2TableWithV1Fallback} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.execution.streaming.{MemoryStream, MemoryStreamScanBuilder} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 273658fcfa..41d1156875 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.test import java.io.File -import java.util.Locale +import java.util.{Locale, Random} import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ @@ -1219,4 +1219,40 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with } } } + + test("SPARK-26164: Allow concurrent writers for multiple partitions and buckets") { + withTable("t1", "t2") { + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) + val df = spark.range(200).map(_ => { + val n = r.nextInt() + (n, n.toString, n % 5) + }).toDF("k1", "k2", "part") + df.write.format("parquet").saveAsTable("t1") + spark.sql("CREATE TABLE t2(k1 int, k2 string, part int) USING parquet PARTITIONED " + + "BY (part) CLUSTERED BY (k1) INTO 3 BUCKETS") + val queryToInsertTable = "INSERT OVERWRITE TABLE t2 SELECT k1, k2, part FROM t1" + + Seq( + // Single writer + 0, + // Concurrent writers without fallback + 200, + // concurrent writers with fallback + 3 + ).foreach { maxWriters => + withSQLConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS.key -> maxWriters.toString) { + 
spark.sql(queryToInsertTable).collect() + checkAnswer(spark.table("t2").orderBy("k1"), + spark.table("t1").orderBy("k1")) + + withSQLConf(SQLConf.MAX_RECORDS_PER_FILE.key -> "1") { + spark.sql(queryToInsertTable).collect() + checkAnswer(spark.table("t2").orderBy("k1"), + spark.table("t1").orderBy("k1")) + } + } + } + } + } } diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 4108d0f04b..729d3f4142 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -72,6 +72,13 @@ <type>test-jar</type> <scope>test</scope> </dependency> + <dependency> + <groupId>org.apache.parquet</groupId> + <artifactId>parquet-hadoop</artifactId> + <version>${parquet.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency>