diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index a9d0e41aa3cbb..7b692201dac11 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -145,6 +145,8 @@ List buildClassPath(String appClassPath) throws IOException { boolean prependClasses = !isEmpty(getenv("SPARK_PREPEND_CLASSES")); boolean isTesting = "1".equals(getenv("SPARK_TESTING")); + boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING")); + String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql); if (prependClasses || isTesting) { String scala = getScalaVersion(); List projects = Arrays.asList( @@ -176,6 +178,9 @@ List buildClassPath(String appClassPath) throws IOException { "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + "assembly."); } + boolean shouldPrePendSparkHive = isJarAvailable(jarsDir, "spark-hive_"); + boolean shouldPrePendSparkHiveThriftServer = + shouldPrePendSparkHive && isJarAvailable(jarsDir, "spark-hive-thriftserver_"); for (String project : projects) { // Do not use locally compiled class files for Spark server because it should use shaded // dependencies. @@ -185,6 +190,24 @@ List buildClassPath(String appClassPath) throws IOException { if (isRemote && "1".equals(getenv("SPARK_SCALA_SHELL")) && project.equals("sql/core")) { continue; } + // SPARK-49534: The assumption here is that if `spark-hive_xxx.jar` is not in the + // classpath, then the `-Phive` profile was not used during package, and therefore + // the Hive-related jars should also not be in the classpath. To avoid failure in + // loading the SPI in `DataSourceRegister` under `sql/hive`, no longer prepend `sql/hive`. + if (!shouldPrePendSparkHive && project.equals("sql/hive")) { + continue; + } + // SPARK-49534: Meanwhile, due to the strong dependency of `sql/hive-thriftserver` + // on `sql/hive`, the prepend for `sql/hive-thriftserver` will also be excluded + // if `spark-hive_xxx.jar` is not in the classpath. On the other hand, if + // `spark-hive-thriftserver_xxx.jar` is not in the classpath, then the + // `-Phive-thriftserver` profile was not used during package, and therefore, + // jars such as hive-cli and hive-beeline should also not be included in the classpath. + // To avoid the inelegant startup failures of tools such as spark-sql, in this scenario, + // `sql/hive-thriftserver` will no longer be prepended to the classpath. + if (!shouldPrePendSparkHiveThriftServer && project.equals("sql/hive-thriftserver")) { + continue; + } addToClassPath(cp, String.format("%s/%s/target/scala-%s/classes", sparkHome, project, scala)); } @@ -205,8 +228,6 @@ List buildClassPath(String appClassPath) throws IOException { // Add Spark jars to the classpath. For the testing case, we rely on the test code to set and // propagate the test classpath appropriately. For normal invocation, look for the jars // directory under SPARK_HOME. 
- boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING")); - String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql); if (jarsDir != null) { // Place slf4j-api-* jar first to be robust for (File f: new File(jarsDir).listFiles()) { @@ -265,6 +286,24 @@ private void addToClassPath(Set cp, String entries) { } } + /** + * Checks if a JAR file with a specific prefix is available in the given directory. + * + * @param jarsDir the directory to search for JAR files + * @param jarNamePrefix the prefix of the JAR file name to look for + * @return true if a JAR file with the specified prefix is found, false otherwise + */ + private boolean isJarAvailable(String jarsDir, String jarNamePrefix) { + if (jarsDir != null) { + for (File f : new File(jarsDir).listFiles()) { + if (f.getName().startsWith(jarNamePrefix)) { + return true; + } + } + } + return false; + } + String getScalaVersion() { String scala = getenv("SPARK_SCALA_VERSION"); if (scala != null) { diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 92d4a3357319f..4be345201ba65 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -43,6 +43,7 @@ ) from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote from pyspark import pandas as ps +from pyspark.pandas.spark import functions as SF from pyspark.pandas._typing import Label from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale from pyspark.pandas.data_type_ops.base import DataTypeOps @@ -938,19 +939,10 @@ def attach_distributed_sequence_column( +--------+---+ """ if len(sdf.columns) > 0: - if is_remote(): - from pyspark.sql.connect.column import Column as ConnectColumn - from pyspark.sql.connect.expressions import DistributedSequenceID - - return sdf.select( - ConnectColumn(DistributedSequenceID()).alias(column_name), - "*", - ) - else: - return PySparkDataFrame( - sdf._jdf.toDF().withSequenceColumn(column_name), - sdf.sparkSession, - ) + return sdf.select( + SF.distributed_sequence_id().alias(column_name), + "*", + ) else: cnt = sdf.count() if cnt > 0: diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index 6aaa63956c14b..4bcf07f6f6503 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -174,6 +174,18 @@ def null_index(col: Column) -> Column: return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc)) +def distributed_sequence_id() -> Column: + if is_remote(): + from pyspark.sql.connect.functions.builtin import _invoke_function + + return _invoke_function("distributed_sequence_id") + else: + from pyspark import SparkContext + + sc = SparkContext._active_spark_context + return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id()) + + def collect_top_k(col: Column, num: int, reverse: bool) -> Column: if is_remote(): from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 344ba8d009ac4..1bdd2dbd8f016 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -7305,36 +7305,36 @@ def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> | b| 2| +---+---+ >>> w = Window.partitionBy("c1").orderBy("c2") - >>> df.withColumn("previos_value", lag("c2").over(w)).show() - +---+---+-------------+ - | c1| c2|previos_value| - +---+---+-------------+ - | a| 1| 
NULL| - | a| 2| 1| - | a| 3| 2| - | b| 2| NULL| - | b| 8| 2| - +---+---+-------------+ - >>> df.withColumn("previos_value", lag("c2", 1, 0).over(w)).show() - +---+---+-------------+ - | c1| c2|previos_value| - +---+---+-------------+ - | a| 1| 0| - | a| 2| 1| - | a| 3| 2| - | b| 2| 0| - | b| 8| 2| - +---+---+-------------+ - >>> df.withColumn("previos_value", lag("c2", 2, -1).over(w)).show() - +---+---+-------------+ - | c1| c2|previos_value| - +---+---+-------------+ - | a| 1| -1| - | a| 2| -1| - | a| 3| 1| - | b| 2| -1| - | b| 8| -1| - +---+---+-------------+ + >>> df.withColumn("previous_value", lag("c2").over(w)).show() + +---+---+--------------+ + | c1| c2|previous_value| + +---+---+--------------+ + | a| 1| NULL| + | a| 2| 1| + | a| 3| 2| + | b| 2| NULL| + | b| 8| 2| + +---+---+--------------+ + >>> df.withColumn("previous_value", lag("c2", 1, 0).over(w)).show() + +---+---+--------------+ + | c1| c2|previous_value| + +---+---+--------------+ + | a| 1| 0| + | a| 2| 1| + | a| 3| 2| + | b| 2| 0| + | b| 8| 2| + +---+---+--------------+ + >>> df.withColumn("previous_value", lag("c2", 2, -1).over(w)).show() + +---+---+--------------+ + | c1| c2|previous_value| + +---+---+--------------+ + | a| 1| -1| + | a| 2| -1| + | a| 3| 1| + | b| 2| -1| + | b| 8| -1| + +---+---+--------------+ """ from pyspark.sql.classic.column import _to_java_column diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 870571b533d09..0fab60a948423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2010,14 +2010,6 @@ class Dataset[T] private[sql]( // For Python API //////////////////////////////////////////////////////////////////////////// - /** - * It adds a new long column with the name `name` that increases one by one. - * This is for 'distributed-sequence' default index in pandas API on Spark. - */ - private[sql] def withSequenceColumn(name: String) = { - select(column(DistributedSequenceID()).alias(name), col("*")) - } - /** * Converts a JavaRDD to a PythonRDD. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 7dbc586f64730..93082740cca64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -176,6 +176,13 @@ private[sql] object PythonSQLUtils extends Logging { def pandasCovar(col1: Column, col2: Column, ddof: Int): Column = Column.internalFn("pandas_covar", col1, col2, lit(ddof)) + /** + * A long column that increases one by one. + * This is for 'distributed-sequence' default index in pandas API on Spark. 
+ */ + def distributed_sequence_id(): Column = + Column.internalFn("distributed_sequence_id") + def unresolvedNamedLambdaVariable(name: String): Column = Column(internal.UnresolvedNamedLambdaVariable.apply(name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala index aa393211a1c15..06fdc6c53bc4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.ThreadUtils @@ -40,8 +41,14 @@ trait AsyncLogPurge extends Logging { private val purgeRunning = new AtomicBoolean(false) + private val statefulMetadataPurgeRunning = new AtomicBoolean(false) + protected def purge(threshold: Long): Unit + // This method is used to purge the oldest OperatorStateMetadata and StateSchema files + // which are written per run. + protected def purgeStatefulMetadata(plan: SparkPlan): Unit + protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE) protected def purgeAsync(batchId: Long): Unit = { @@ -62,6 +69,24 @@ trait AsyncLogPurge extends Logging { } } + protected def purgeStatefulMetadataAsync(plan: SparkPlan): Unit = { + if (statefulMetadataPurgeRunning.compareAndSet(false, true)) { + asyncPurgeExecutorService.execute(() => { + try { + purgeStatefulMetadata(plan) + } catch { + case throwable: Throwable => + logError("Encountered error while performing async log purge", throwable) + errorNotifier.markError(throwable) + } finally { + statefulMetadataPurgeRunning.set(false) + } + }) + } else { + log.debug("Skipped log purging since there is already one in progress.") + } + } + protected def asyncLogPurgeShutdown(): Unit = { ThreadUtils.shutdown(asyncPurgeExecutorService) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index f59cdca8aefec..285494543533c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -842,6 +842,10 @@ class MicroBatchExecution( markMicroBatchExecutionStart(execCtx) + if (execCtx.previousContext.isEmpty) { + purgeStatefulMetadataAsync(execCtx.executionPlan.executedPlan) + } + val nextBatch = new Dataset(execCtx.executionPlan, ExpressionEncoder(execCtx.executionPlan.analyzed.schema)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index ee5f806a3fb35..4b1b9e02a242a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -41,8 +41,10 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table} import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, 
SparkDataStream} import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchUserFuncException, ForeachUserFuncException} +import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataV2FileManager import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend import org.apache.spark.sql.streaming._ @@ -686,6 +688,31 @@ abstract class StreamExecution( offsetLog.purge(threshold) commitLog.purge(threshold) } + + protected def purgeStatefulMetadata(plan: SparkPlan): Unit = { + plan.collect { case statefulOperator: StatefulOperator => + statefulOperator match { + case ssw: StateStoreWriter => + ssw.operatorStateMetadataVersion match { + case 2 => + // checkpointLocation of the operator is runId/state, and commitLog path is + // runId/commits, so we want the parent of the checkpointLocation to get the + // commit log path. + val parentCheckpointLocation = + new Path(statefulOperator.getStateInfo.checkpointLocation).getParent + + val fileManager = new OperatorStateMetadataV2FileManager( + parentCheckpointLocation, + sparkSession, + ssw + ) + fileManager.purgeMetadataFiles() + case _ => + } + case _ => + } + } + } } object StreamExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index 3e68b3975e662..aa2f332afeff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -23,14 +23,14 @@ import java.nio.charset.StandardCharsets import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path} +import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path, PathFilter} import org.json4s.{Formats, NoTypeHints} import org.json4s.jackson.Serialization import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors -import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, CommitLog, MetadataVersionUtil, OffsetSeqLog} +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, CommitLog, MetadataVersionUtil, OffsetSeqLog, StateStoreWriter} import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS} import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataUtils.{OperatorStateMetadataReader, OperatorStateMetadataWriter} @@ -358,3 +358,121 @@ class OperatorStateMetadataV2Reader( } } } + +/** + * A helper class to manage the metadata files for the operator state checkpoint. 
+ * This class is used to manage the metadata files for OperatorStateMetadataV2, and + * provides utils to purge the oldest files such that we only keep the metadata files + * for which a commit log is present + * @param checkpointLocation The root path of the checkpoint directory + * @param sparkSession The sparkSession that is used to access the hadoopConf + * @param stateStoreWriter The operator that this fileManager is being created for + */ +class OperatorStateMetadataV2FileManager( + checkpointLocation: Path, + sparkSession: SparkSession, + stateStoreWriter: StateStoreWriter) extends Logging { + + private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private val stateCheckpointPath = new Path(checkpointLocation, "state") + private val stateOpIdPath = new Path( + stateCheckpointPath, stateStoreWriter.getStateInfo.operatorId.toString) + private val commitLog = + new CommitLog(sparkSession, new Path(checkpointLocation, "commits").toString) + private val stateSchemaPath = stateStoreWriter.stateSchemaDirPath() + private val metadataDirPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + private lazy val fm = CheckpointFileManager.create(metadataDirPath, hadoopConf) + + protected def isBatchFile(path: Path): Boolean = { + try { + path.getName.toLong + true + } catch { + case _: NumberFormatException => false + } + } + + /** + * A `PathFilter` to filter only batch files + */ + protected val batchFilesFilter: PathFilter = (path: Path) => isBatchFile(path) + + private def pathToBatchId(path: Path): Long = { + path.getName.toLong + } + + def purgeMetadataFiles(): Unit = { + val thresholdBatchId = findThresholdBatchId() + if (thresholdBatchId != 0) { + val earliestBatchIdKept = deleteMetadataFiles(thresholdBatchId) + // we need to delete everything from 0 to (earliestBatchIdKept - 1), inclusive + deleteSchemaFiles(earliestBatchIdKept - 1) + } + } + + // We only want to keep the metadata and schema files for which the commit + // log is present, so we will delete any file that precedes the batch for the oldest + // commit log + private def findThresholdBatchId(): Long = { + commitLog.listBatchesOnDisk.headOption.getOrElse(0L) + } + + private def deleteSchemaFiles(thresholdBatchId: Long): Unit = { + val schemaFiles = fm.list(stateSchemaPath).sorted.map(_.getPath) + val filesBeforeThreshold = schemaFiles.filter { path => + val batchIdInPath = path.getName.split("_").head.toLong + batchIdInPath <= thresholdBatchId + } + filesBeforeThreshold.foreach { path => + fm.delete(path) + } + } + + // Deletes all metadata files that are below a thresholdBatchId, except + // for the latest metadata file so that we have at least 1 metadata and schema + // file at all times per each stateful query + // Returns the batchId of the earliest schema file we want to keep + private def deleteMetadataFiles(thresholdBatchId: Long): Long = { + val metadataFiles = fm.list(metadataDirPath, batchFilesFilter) + + if (metadataFiles.isEmpty) { + return -1L // No files to delete + } + + // get all the metadata files for which we don't have commit logs + val sortedBatchIds = metadataFiles + .map(file => pathToBatchId(file.getPath)) + .filter(_ <= thresholdBatchId) + .sorted + + if (sortedBatchIds.isEmpty) { + return -1L + } + + // we don't want to delete the batchId right before the last one + val latestBatchId = sortedBatchIds.last + + metadataFiles.foreach { batchFile => + val batchId = pathToBatchId(batchFile.getPath) + if (batchId < latestBatchId) { + fm.delete(batchFile.getPath) + } + } + val 
latestMetadata = OperatorStateMetadataReader.createReader( + stateOpIdPath, + hadoopConf, + 2, + latestBatchId + ).read() + + // find the batchId of the earliest schema file we need to keep + val earliestBatchToKeep = latestMetadata match { + case Some(OperatorStateMetadataV2(_, stateStoreInfo, _)) => + val schemaFilePath = stateStoreInfo.head.stateSchemaFilePath + new Path(schemaFilePath).getName.split("_").head.toLong + case _ => 0 + } + + earliestBatchToKeep + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 47b1cb90e00a8..90eb634689b23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -186,6 +186,7 @@ class StateSchemaCompatibilityChecker( check(existingStateSchema, newSchema, ignoreValueSchema) } } + // TODO: [SPARK-49535] Write Schema files after schema has changed for StateSchemaV3 false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c092adb354c2d..3cb41710a22c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -166,17 +166,19 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp "number of state store instances") ) ++ stateStoreCustomMetrics ++ pythonMetrics - def stateSchemaFilePath(storeName: Option[String] = None): Path = { - def stateInfo = getStateInfo + // This method is only used to fetch the state schema directory path for + // operators that use StateSchemaV3, as prior versions only use a single + // set file path. 
+ def stateSchemaDirPath( + storeName: Option[String] = None): Path = { val stateCheckpointPath = new Path(getStateInfo.checkpointLocation, - s"${stateInfo.operatorId.toString}") + s"${getStateInfo.operatorId.toString}") storeName match { case Some(storeName) => - val storeNamePath = new Path(stateCheckpointPath, storeName) - new Path(new Path(storeNamePath, "_metadata"), "schema") + new Path(new Path(stateCheckpointPath, "_stateSchema"), storeName) case None => - new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + new Path(new Path(stateCheckpointPath, "_stateSchema"), "default") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 8b4fa90b3119d..52b8d35e2fbf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, UnresolvedTableSpec, View} import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{CatalogManager, SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, MultipartIdentifierHelper, NamespaceHelper, TransformHelper} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -671,12 +672,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } else { CatalogTableType.MANAGED } - val location = if (storage.locationUri.isDefined) { - val locationStr = storage.locationUri.get.toString - Some(locationStr) - } else { - None - } + + // The location in UnresolvedTableSpec should be the original user-provided path string. 
+ val location = CaseInsensitiveMap(options).get("path") val newOptions = OptionList(options.map { case (key, value) => (key, Literal(value).asInstanceOf[Expression]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala index 7248a2d3f056e..f0eef9ae1cbb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala @@ -426,8 +426,10 @@ final class DataFrameWriterImpl[T] private[sql](ds: Dataset[T]) extends DataFram import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ val session = df.sparkSession - val canUseV2 = lookupV2Provider().isDefined || - df.sparkSession.sessionState.conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined + val canUseV2 = lookupV2Provider().isDefined || (df.sparkSession.sessionState.conf.getConf( + SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined && + !df.sparkSession.sessionState.catalogManager.catalog(CatalogManager.SESSION_CATALOG_NAME) + .isInstanceOf[DelegatingCatalogExtension]) session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index c2e6abf184b55..284ccc5d5bfe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -176,6 +176,10 @@ abstract class CompoundNestedStatementIteratorExec( handleLeaveStatement(leaveStatement) curr = None leaveStatement + case Some(iterateStatement: IterateStatementExec) => + handleIterateStatement(iterateStatement) + curr = None + iterateStatement case Some(statement: LeafStatementExec) => curr = if (localIterator.hasNext) Some(localIterator.next()) else None statement @@ -185,6 +189,9 @@ abstract class CompoundNestedStatementIteratorExec( case leaveStatement: LeaveStatementExec => handleLeaveStatement(leaveStatement) leaveStatement + case iterateStatement: IterateStatementExec => + handleIterateStatement(iterateStatement) + iterateStatement case other => other } } else { @@ -206,18 +213,35 @@ abstract class CompoundNestedStatementIteratorExec( stopIteration = false } - /** Actions to do when LEAVE statement is encountered to stop the execution of this compound. */ + /** Actions to do when LEAVE statement is encountered, to stop the execution of this compound. */ private def handleLeaveStatement(leaveStatement: LeaveStatementExec): Unit = { if (!leaveStatement.hasBeenMatched) { // Stop the iteration. stopIteration = true // TODO: Variable cleanup (once we add SQL script execution logic). + // TODO: Add interpreter tests as well. // Check if label has been matched. leaveStatement.hasBeenMatched = label.isDefined && label.get.equals(leaveStatement.label) } } + + /** + * Actions to do when ITERATE statement is encountered, to stop the execution of this compound. + */ + private def handleIterateStatement(iterateStatement: IterateStatementExec): Unit = { + if (!iterateStatement.hasBeenMatched) { + // Stop the iteration. + stopIteration = true + + // TODO: Variable cleanup (once we add SQL script execution logic). + // TODO: Add interpreter tests as well. 
+ + // No need to check if label has been matched, since ITERATE statement is already + // not allowed in CompoundBody. + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 5e7feec149c97..031bf337cec90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAg import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} -import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.{ArrayType, MapType, StringType, StructField, StructType} class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { @@ -158,7 +157,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } test("disable bucketing on collated string column") { - spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) def createTable(bucketColumns: String*): Unit = { val tableName = "test_partition_tbl" withTable(tableName) { @@ -760,7 +758,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } test("disable partition on collated string column") { - spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) def createTable(partitionColumns: String*): Unit = { val tableName = "test_partition_tbl" withTable(tableName) { @@ -950,7 +947,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(table) { sql(s"create table $table (a array) using parquet") sql(s"insert into $table values (array('aaa')), (array('AAA'))") - checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa")))) + val result = sql(s"select distinct a from $table").collect() + assert(result.length === 1) + val data = result.head.getSeq[String](0) + assert(data === Array("aaa") || data === Array("AAA")) } // map doesn't support aggregation withTable(table) { @@ -971,7 +971,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(table) { sql(s"create table $table (s struct) using parquet") sql(s"insert into $table values (named_struct('fld', 'aaa')), (named_struct('fld', 'AAA'))") - checkAnswer(sql(s"select s.fld from $table group by s"), Seq(Row("aaa"))) + val result = sql(s"select s.fld from $table group by s").collect() + assert(result.length === 1) + val data = result.head.getString(0) + assert(data === "aaa" || data === "AAA") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 310b5a62c908a..d888b09d76eac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} import 
org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.functions.{count, explode, sum, year} +import org.apache.spark.sql.functions.{col, count, explode, sum, year} import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -404,7 +405,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y"))) // Test for AttachDistributedSequence - val df13 = df1.withSequenceColumn("seq") + val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*")) val df14 = df13.filter($"value" === "A2") assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2"))) assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b1c41033fd760..9bfbdda33c36d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.matchers.should.Matchers._ import org.apache.spark.SparkException import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -2316,7 +2317,8 @@ class DataFrameSuite extends QueryTest } test("SPARK-36338: DataFrame.withSequenceColumn should append unique sequence IDs") { - val ids = spark.range(10).repartition(5).withSequenceColumn("default_index") + val ids = spark.range(10).repartition(5).select( + distributed_sequence_id().alias("default_index"), col("id")) assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet) assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 7bbb6485c273f..fe078c5ae4413 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -55,8 +55,7 @@ class DataSourceV2DataFrameSessionCatalogSuite "and a same-name temp view exist") { withTable("same_name") { withTempView("same_name") { - val format = spark.sessionState.conf.defaultDataSourceName - sql(s"CREATE TABLE same_name(id LONG) USING $format") + sql(s"CREATE TABLE same_name(id LONG) USING $v2Format") spark.range(10).createTempView("same_name") spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name") checkAnswer(spark.table("same_name"), spark.range(10).toDF()) @@ -88,6 +87,15 @@ class DataSourceV2DataFrameSessionCatalogSuite assert(tableInfo.properties().get("provider") === v2Format) } } + + test("SPARK-49246: saveAsTable with v1 format") { + withTable("t") { + sql("CREATE TABLE t(c INT) USING csv") + val df = spark.range(10).toDF() + df.write.mode(SaveMode.Overwrite).format("csv").saveAsTable("t") + verifyTable("t", df) + } + } } class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable] { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala index ff944dbb805cb..2254abef3fcb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala @@ -22,8 +22,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.catalog.CatalogTableType -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, DelegatingCatalogExtension, Identifier, Table, TableCatalog, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, DelegatingCatalogExtension, Identifier, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.types.StructType @@ -53,14 +52,10 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating if (tables.containsKey(ident)) { tables.get(ident) } else { - // Table was created through the built-in catalog - super.loadTable(ident) match { - case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW => v1Table - case t => - val table = newTable(t.name(), t.schema(), t.partitioning(), t.properties()) - addTable(ident, table) - table - } + // Table was created through the built-in catalog via v1 command, this is OK as the + // `loadTable` should always be invoked, and we set the `tableCreated` to pass validation. + tableCreated.set(true) + super.loadTable(ident) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 7c929b5da872a..227280a7626e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -700,7 +700,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf val description = "this is a test table" withTable("t") { - withTempDir { dir => + withTempDir { baseDir => + val dir = new File(baseDir, "test%prefix") spark.catalog.createTable( tableName = "t", source = "json", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 6d836884c5d58..4b47529591c04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -678,4 +678,38 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { ) verifySqlScriptResult(sqlScriptText, expected) } + + test("nested compounds in loop - leave in inner compound") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | BEGIN + | SELECT 1; + | lbl2: BEGIN + | SELECT 2; + | LEAVE lbl2; + | SELECT 3; + | END; + | END; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq(Row(1)), // select 1 + Seq(Row(2)), // select 2 + Seq.empty[Row], // set x = 2 + Seq(Row(1)), // select 1 + Seq(Row(2)), // select 2 + Seq(Row(2)), // select x + Seq.empty[Row] // drop + ) + 
verifySqlScriptResult(sqlScriptText, expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 9eeedd8598092..cdd4c891d5000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -21,12 +21,16 @@ import java.io.File import java.time.Duration import java.util.UUID -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} +import org.scalatest.matchers.must.Matchers.be +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.scalatest.time.{Seconds, Span} import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.util.stringToFile +import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state._ @@ -1426,6 +1430,284 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } } + + private def getFiles(path: Path): Array[FileStatus] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val fileManager = CheckpointFileManager.create(path, hadoopConf) + fileManager.list(path) + } + + private def getStateSchemaPath(stateCheckpointPath: Path): Path = { + new Path(stateCheckpointPath, "_stateSchema/default/") + } + + test("transformWithState - verify that metadata and schema logs are purged") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") { + withTempDir { chkptDir => + // in this test case, we are changing the state spec back and forth + // to trigger the writing of the schema and metadata files + val inputData = MemoryStream[(String, String)] + val result1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + TimeMode.None(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str2")), + CheckNewAnswer(("a", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + val result3 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result3, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + 
AddData(inputData, ("a", "str3")), + CheckNewAnswer(("a", "1", "str2")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + // because we don't change the schema for this run, there won't + // be a new schema file written. + testStream(result3, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str4")), + CheckNewAnswer(("a", "2", "str3")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") + val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + + val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + // by the end of the test, there have been 4 batches, + // so the metadata and schema logs, and commitLog has been purged + // for batches 0 and 1 so metadata and schema files exist for batches 0, 1, 2, 3 + // and we only need to keep metadata files for batches 2, 3, and the since schema + // hasn't changed between batches 2, 3, we only keep the schema file for batch 2 + assert(getFiles(metadataPath).length == 2) + assert(getFiles(stateSchemaPath).length == 1) + } + } + } + + test("state data source integration - value state supports time travel") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "5") { + withTempDir { chkptDir => + // in this test case, we are changing the state spec back and forth + // to trigger the writing of the schema and metadata files + val inputData = MemoryStream[(String, String)] + val result1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "2", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "3", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "4", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "5", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "6", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "7", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + val result2 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new MostRecentStatefulProcessorWithDeletion(), + TimeMode.None(), + OutputMode.Update()) + testStream(result2, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str2")), + CheckNewAnswer(("a", "str1")), + AddData(inputData, ("a", "str3")), + CheckNewAnswer(("a", "str2")), + AddData(inputData, ("a", "str4")), + CheckNewAnswer(("a", "str3")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + + // Batches 0-7: countState, 
mostRecent + // Batches 8-9: countState + + // By this time, offset and commit logs for batches 0-3 have been purged. + // However, if we want to read the data for batch 4, we need to read the corresponding + // metadata and schema file for batch 4, or the latest files that correspond to + // batch 4 (in this case, the files were written for batch 0). + // We want to test the behavior where the metadata files are preserved so that we can + // read from the state data source, even if the commit and offset logs are purged for + // a particular batch + + val df = spark.read.format("state-metadata").load(chkptDir.toString) + + // check the min and max batch ids that we have data for + checkAnswer( + df.select( + "operatorId", "operatorName", "stateStoreName", "numPartitions", "minBatchId", + "maxBatchId"), + Seq(Row(0, "transformWithStateExec", "default", 5, 4, 9)) + ) + + val countStateDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .option(StateSourceOptions.BATCH_ID, 4) + .load() + + val countStateAnsDf = countStateDf.selectExpr( + "key.value AS groupingKey", + "single_value.value AS valueId") + checkAnswer(countStateAnsDf, Seq(Row("a", 5L))) + + val mostRecentDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, chkptDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "mostRecent") + .option(StateSourceOptions.BATCH_ID, 4) + .load() + + val mostRecentAnsDf = mostRecentDf.selectExpr( + "key.value AS groupingKey", + "single_value.value") + checkAnswer(mostRecentAnsDf, Seq(Row("a", "str1"))) + } + } + } + + test("transformWithState - verify that all metadata and schema logs are not purged") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString, + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "3") { + withTempDir { chkptDir => + val inputData = MemoryStream[(String, String)] + val result1 = inputData.toDS() + .groupByKey(x => x._1) + .transformWithState(new RunningCountMostRecentStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "1", "")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "2", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "3", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "4", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "5", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "6", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "7", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "8", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "9", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "10", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "11", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "12", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + testStream(result1, OutputMode.Update())( + StartStream(checkpointLocation = chkptDir.getCanonicalPath), + AddData(inputData, 
("a", "str1")), + CheckNewAnswer(("a", "13", "str1")), + AddData(inputData, ("a", "str1")), + CheckNewAnswer(("a", "14", "str1")), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + }, + StopStream + ) + + val stateOpIdPath = new Path(new Path(chkptDir.getCanonicalPath, "state"), "0") + val stateSchemaPath = getStateSchemaPath(stateOpIdPath) + + val metadataPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath) + + // Metadata files exist for batches 0, 12, and the thresholdBatchId is 8 + // as this is the earliest batchId for which the commit log is not present, + // so we need to keep metadata files for batch 0 so we can read the commit + // log correspondingly + assert(getFiles(metadataPath).length == 2) + assert(getFiles(stateSchemaPath).length == 1) + } + } + } } class TransformWithStateValidationSuite extends StateStoreMetricsTest {