From f6a60c1d0d0cd28638e670367e6f1accb73a0238 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 5 Nov 2020 14:09:02 -0800 Subject: [PATCH 1/3] wip --- .../spark/sql/execution/CacheManager.scala | 4 ++ .../datasources/v2/DataSourceV2Strategy.scala | 9 +++-- .../v2/WriteToDataSourceV2Exec.scala | 39 +++++++++++++++++-- 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index f163d85914bc9..abb1d7b121457 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -188,6 +188,10 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { recacheByCondition(spark, _.plan.find(_.sameResult(plan)).isDefined) } + def recacheByExactPlan(spark: SparkSession, plan: LogicalPlan): Unit = { + recacheByCondition(spark, _.plan.sameResult(plan)) + } + /** * Re-caches all the cache entries that satisfies the given `condition`. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 21abfc2816ee4..adfa576043051 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -147,6 +147,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat catalog match { case staging: StagingTableCatalog => AtomicReplaceTableAsSelectExec( + session, staging, ident, parts, @@ -157,6 +158,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat orCreate = orCreate) :: Nil case _ => ReplaceTableAsSelectExec( + session, catalog, ident, parts, @@ -172,7 +174,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil case v2 => - AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil + AppendDataExec(session, v2, r, writeOptions.asOptions, planLater(query)) :: Nil } case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) => @@ -186,12 +188,13 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query) :: Nil case v2 => - OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil + OverwriteByExpressionExec(session, v2, r, filters, + writeOptions.asOptions, planLater(query)) :: Nil } case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) => OverwritePartitionsDynamicExec( - r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil + session, r.table.asWritable, r, writeOptions.asOptions, planLater(query)) :: Nil case DeleteFromTable(relation, condition) => relation match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 1421a9315c3a8..f418a21ad79b3 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.expressions.Attribute @@ -127,6 +128,7 @@ case class AtomicCreateTableAsSelectExec( * ReplaceTableAsSelectStagingExec. */ case class ReplaceTableAsSelectExec( + session: SparkSession, catalog: TableCatalog, ident: Identifier, partitioning: Seq[Transform], @@ -146,6 +148,8 @@ case class ReplaceTableAsSelectExec( // 2. Writing to the new table fails, // 3. The table returned by catalog.createTable doesn't support writing. if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + uncacheTable(session, catalog, table, ident) catalog.dropTable(ident) } else if (!orCreate) { throw new CannotReplaceMissingTableException(ident) @@ -169,6 +173,7 @@ case class ReplaceTableAsSelectExec( * is left untouched. */ case class AtomicReplaceTableAsSelectExec( + session: SparkSession, catalog: StagingTableCatalog, ident: Identifier, partitioning: Seq[Transform], @@ -180,6 +185,10 @@ case class AtomicReplaceTableAsSelectExec( override protected def run(): Seq[InternalRow] = { val schema = query.schema.asNullable + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + uncacheTable(session, catalog, table, ident) + } val staged = if (orCreate) { catalog.stageCreateOrReplace( ident, schema, partitioning.toArray, properties.asJava) @@ -204,12 +213,16 @@ case class AtomicReplaceTableAsSelectExec( * Rows in the output data set are appended. */ case class AppendDataExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def run(): Seq[InternalRow] = { - writeWithV2(newWriteBuilder().buildForBatch()) + val writtenRows = writeWithV2(newWriteBuilder().buildForBatch()) + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } @@ -224,7 +237,9 @@ case class AppendDataExec( * AlwaysTrue to delete all rows. */ case class OverwriteByExpressionExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { @@ -234,7 +249,7 @@ case class OverwriteByExpressionExec( } override protected def run(): Seq[InternalRow] = { - newWriteBuilder() match { + val writtenRows = newWriteBuilder() match { case builder: SupportsTruncate if isTruncate(deleteWhere) => writeWithV2(builder.truncate().buildForBatch()) @@ -244,9 +259,12 @@ case class OverwriteByExpressionExec( case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") } + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } + /** * Physical plan node for dynamic partition overwrite into a v2 table. 
* @@ -257,18 +275,22 @@ case class OverwriteByExpressionExec( * are not modified. */ case class OverwritePartitionsDynamicExec( + session: SparkSession, table: SupportsWrite, + relation: DataSourceV2Relation, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def run(): Seq[InternalRow] = { - newWriteBuilder() match { + val writtenRows = newWriteBuilder() match { case builder: SupportsDynamicOverwrite => writeWithV2(builder.overwriteDynamicPartitions().buildForBatch()) case _ => throw new SparkException(s"Table does not support dynamic partition overwrite: $table") } + session.sharedState.cacheManager.recacheByPlan(session, relation) + writtenRows } } @@ -370,6 +392,16 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { Nil } + + protected def uncacheTable( + session: SparkSession, + catalog: TableCatalog, + table: Table, + ident: Identifier): Unit = { + // TODO: what about options? do they matter when comparing? + val plan = DataSourceV2Relation.create(table, Some(catalog), Some(ident)) + session.sharedState.cacheManager.uncacheQuery(session, plan, cascade = true) + } } object DataWritingSparkTask extends Logging { @@ -484,3 +516,4 @@ private[v2] case class DataWritingSparkTaskResult( * Sink progress information collected after commit. */ private[sql] case class StreamWriterCommitProgress(numOutputRows: Long) + From 01ce2de811c0b7d542fbcfb4a8093e0ffda82716 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Wed, 18 Nov 2020 16:29:05 -0800 Subject: [PATCH 2/3] add tests for v2 --- .../sql/connector/DataSourceV2SQLSuite.scala | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index ddafa1bb5070a..cf3d321d30fa6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -782,6 +782,80 @@ class DataSourceV2SQLSuite } } + test("SPARK-XXXXX: ReplaceTableAsSelect (atomic or non-atomic) should also invalidate cache") { + Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t => + val view = "view" + withTable(t) { + withTempView(view) { + sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source") + sql(s"CACHE TABLE $view AS SELECT id FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source")) + checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id")) + + sql(s"REPLACE TABLE $t USING foo AS SELECT id FROM source") + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isEmpty) + } + } + } + } + + test("SPARK-XXXXX: OverwriteByExpression should also refresh cache") { + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + withTempView(view) { + sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source") + sql(s"CACHE TABLE $view AS SELECT id FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source")) + checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id")) + + sql(s"INSERT OVERWRITE TABLE $t VALUES (1, 'a')") + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + } + } + } + + test("SPARK-XXXXX: AppendTable should also invalidate cache") { + import testImplicits._ + + val t = "testcat.ns.t" + val view = "view" + withTable(t) { + 
withTempView(view) {
+        Seq((1, "a")).toDF("i", "j").write.saveAsTable(t)
+        sql(s"CACHE TABLE $view AS SELECT i FROM $t")
+        checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil)
+        checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil)
+
+        Seq((2, "b")).toDF("i", "j").write.mode(SaveMode.Append).saveAsTable(t)
+
+        checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Row(2, "b") :: Nil)
+        checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Row(2) :: Nil)
+      }
+    }
+  }
+
+  test("SPARK-XXXXX: OverwriteParttiionsDynamic should invalidate cache") {
+    import testImplicits._
+
+    val t = "testcat.ns.t"
+    val view = "view"
+    withTable(t) {
+      withTempView(view) {
+        Seq((1, "a", 1)).toDF("i", "j", "k").write.partitionBy("k").saveAsTable(t)
+        sql(s"CACHE TABLE $view AS SELECT i FROM $t")
+        checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a", 1) :: Nil)
+        checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil)
+
+        Seq((2, "b", 1)).toDF("i", "j", "k").writeTo(t).overwritePartitions()
+
+        checkAnswer(sql(s"SELECT * FROM $t"), Row(2, "b", 1) :: Nil)
+        checkAnswer(sql(s"SELECT * FROM $view"), Row(2) :: Nil)
+      }
+    }
+  }
+
   test("Relation: basic") {
     val t1 = "testcat.ns1.ns2.tbl"
     withTable(t1) {

From ef1a45ba9638c21b04ce0946752f90bfdde3d824 Mon Sep 17 00:00:00 2001
From: Chao Sun
Date: Wed, 18 Nov 2020 23:04:16 -0800
Subject: [PATCH 3/3] handle v1 fallback

---
 .../spark/sql/execution/CacheManager.scala    |  4 -
 .../datasources/v2/DataSourceV2Strategy.scala |  4 +-
 .../datasources/v2/V1FallbackWriters.scala    | 21 +++--
 .../v2/WriteToDataSourceV2Exec.scala          |  1 -
 .../sql/connector/DataSourceV2SQLSuite.scala  | 42 +++++-----
 .../sql/connector/V1WriteFallbackSuite.scala  | 79 ++++++++++++++++++-
 6 files changed, 116 insertions(+), 35 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index abb1d7b121457..f163d85914bc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -188,10 +188,6 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
     recacheByCondition(spark, _.plan.find(_.sameResult(plan)).isDefined)
   }
 
-  def recacheByExactPlan(spark: SparkSession, plan: LogicalPlan): Unit = {
-    recacheByCondition(spark, _.plan.sameResult(plan))
-  }
-
   /**
    * Re-caches all the cache entries that satisfies the given `condition`.
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index adfa576043051..e5c29312b80e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -172,7 +172,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case AppendData(r: DataSourceV2Relation, query, writeOptions, _) => r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil + AppendDataExecV1(v1, writeOptions.asOptions, query, r) :: Nil case v2 => AppendDataExec(session, v2, r, writeOptions.asOptions, planLater(query)) :: Nil } @@ -186,7 +186,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat }.toArray r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query) :: Nil + OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query, r) :: Nil case v2 => OverwriteByExpressionExec(session, v2, r, filters, writeOptions.asOptions, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala index 560da39314b36..af7721588edeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala @@ -37,10 +37,11 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class AppendDataExecV1( table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, - plan: LogicalPlan) extends V1FallbackWriters { + plan: LogicalPlan, + v2Relation: DataSourceV2Relation) extends V1FallbackWriters { override protected def run(): Seq[InternalRow] = { - writeWithV1(newWriteBuilder().buildForV1Write()) + writeWithV1(newWriteBuilder().buildForV1Write(), Some(v2Relation)) } } @@ -59,7 +60,8 @@ case class OverwriteByExpressionExecV1( table: SupportsWrite, deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, - plan: LogicalPlan) extends V1FallbackWriters { + plan: LogicalPlan, + v2Relation: DataSourceV2Relation) extends V1FallbackWriters { private def isTruncate(filters: Array[Filter]): Boolean = { filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] @@ -68,10 +70,10 @@ case class OverwriteByExpressionExecV1( override protected def run(): Seq[InternalRow] = { newWriteBuilder() match { case builder: SupportsTruncate if isTruncate(deleteWhere) => - writeWithV1(builder.truncate().asV1Builder.buildForV1Write()) + writeWithV1(builder.truncate().asV1Builder.buildForV1Write(), Some(v2Relation)) case builder: SupportsOverwrite => - writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write()) + writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write(), Some(v2Relation)) case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") @@ -112,9 +114,14 @@ sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write { trait SupportsV1Write extends SparkPlan { def plan: LogicalPlan - protected def 
writeWithV1(relation: InsertableRelation): Seq[InternalRow] = { + protected def writeWithV1( + relation: InsertableRelation, + v2Relation: Option[DataSourceV2Relation] = None): Seq[InternalRow] = { + val session = sqlContext.sparkSession // The `plan` is already optimized, we should not analyze and optimize it again. - relation.insert(AlreadyOptimized.dataFrame(sqlContext.sparkSession, plan), overwrite = false) + relation.insert(AlreadyOptimized.dataFrame(session, plan), overwrite = false) + v2Relation.foreach(r => session.sharedState.cacheManager.recacheByPlan(session, r)) + Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index f418a21ad79b3..1648134d0a1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -398,7 +398,6 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { catalog: TableCatalog, table: Table, ident: Identifier): Unit = { - // TODO: what about options? do they matter when comparing? val plan = DataSourceV2Relation.create(table, Some(catalog), Some(ident)) session.sharedState.cacheManager.uncacheQuery(session, plan, cascade = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index cf3d321d30fa6..a92fdb14cd46f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -782,7 +782,7 @@ class DataSourceV2SQLSuite } } - test("SPARK-XXXXX: ReplaceTableAsSelect (atomic or non-atomic) should also invalidate cache") { + test("SPARK-33492: ReplaceTableAsSelect (atomic or non-atomic) should invalidate cache") { Seq("testcat.ns.t", "testcat_atomic.ns.t").foreach { t => val view = "view" withTable(t) { @@ -799,44 +799,47 @@ class DataSourceV2SQLSuite } } - test("SPARK-XXXXX: OverwriteByExpression should also refresh cache") { + test("SPARK-33492: AppendData should refresh cache") { + import testImplicits._ + val t = "testcat.ns.t" val view = "view" withTable(t) { withTempView(view) { - sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source") - sql(s"CACHE TABLE $view AS SELECT id FROM $t") - checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source")) - checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id")) - - sql(s"INSERT OVERWRITE TABLE $t VALUES (1, 'a')") + Seq((1, "a")).toDF("i", "j").write.saveAsTable(t) + sql(s"CACHE TABLE $view AS SELECT i FROM $t") checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil) checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + + Seq((2, "b")).toDF("i", "j").write.mode(SaveMode.Append).saveAsTable(t) + + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Row(2, "b") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Row(2) :: Nil) } } } - test("SPARK-XXXXX: AppendTable should also invalidate cache") { - import testImplicits._ - + test("SPARK-33492: OverwriteByExpression should refresh cache") { val t = "testcat.ns.t" val view = "view" withTable(t) { withTempView(view) { - Seq((1, "a")).toDF("i", 
"j").write.saveAsTable(t) - sql(s"CACHE TABLE $view AS SELECT i FROM $t") - checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil) - checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) + sql(s"CREATE TABLE $t USING foo AS SELECT id, data FROM source") + sql(s"CACHE TABLE $view AS SELECT id FROM $t") + checkAnswer(sql(s"SELECT * FROM $t"), spark.table("source")) + checkAnswer(sql(s"SELECT * FROM $view"), spark.table("source").select("id")) - Seq((2, "b")).toDF("i", "j").write.mode(SaveMode.Append).saveAsTable(t) + sql(s"INSERT OVERWRITE TABLE $t VALUES (1, 'a')") - checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Row(2, "b") :: Nil) - checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Row(2) :: Nil) + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) + checkAnswer(sql(s"SELECT * FROM $t"), Row(1, "a") :: Nil) + checkAnswer(sql(s"SELECT * FROM $view"), Row(1) :: Nil) } } } - test("SPARK-XXXXX: OverwriteParttiionsDynamic should invalidate cache") { + test("SPARK-33492: OverwritePartitionsDynamic should refresh cache") { import testImplicits._ val t = "testcat.ns.t" @@ -850,6 +853,7 @@ class DataSourceV2SQLSuite Seq((2, "b", 1)).toDF("i", "j", "k").writeTo(t).overwritePartitions() + assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(view)).isDefined) checkAnswer(sql(s"SELECT * FROM $t"), Row(2, "b", 1) :: Nil) checkAnswer(sql(s"SELECT * FROM $view"), Row(2) :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 4b52a4cbf4116..cba7dd35fb3bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -24,14 +24,17 @@ import scala.collection.mutable import org.scalatest.BeforeAndAfter +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan} import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.sources._ @@ -145,6 +148,52 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before SparkSession.setDefaultSession(spark) } } + + test("SPARK-33492: append fallback should refresh cache") { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + try { + val session = SparkSession.builder() + .master("local[1]") + .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName) + .getOrCreate() + val df = 
session.createDataFrame(Seq((1, "x"))) + df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test") + session.catalog.cacheTable("test") + checkAnswer(session.read.table("test"), Row(1, "x") :: Nil) + + val df2 = session.createDataFrame(Seq((2, "y"))) + df2.writeTo("test").append() + checkAnswer(session.read.table("test"), Row(1, "x") :: Row(2, "y") :: Nil) + + } finally { + SparkSession.setActiveSession(spark) + SparkSession.setDefaultSession(spark) + } + } + + test("SPARK-33492: overwrite fallback should refresh cache") { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + try { + val session = SparkSession.builder() + .master("local[1]") + .config(V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[V1FallbackTableCatalog].getName) + .getOrCreate() + val df = session.createDataFrame(Seq((1, "x"))) + df.write.mode("append").option("name", "t1").format(v2Format).saveAsTable("test") + session.catalog.cacheTable("test") + checkAnswer(session.read.table("test"), Row(1, "x") :: Nil) + + val df2 = session.createDataFrame(Seq((2, "y"))) + df2.writeTo("test").overwrite(lit(true)) + checkAnswer(session.read.table("test"), Row(2, "y") :: Nil) + + } finally { + SparkSession.setActiveSession(spark) + SparkSession.setDefaultSession(spark) + } + } } class V1WriteFallbackSessionCatalogSuite @@ -177,6 +226,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV properties: util.Map[String, String]): InMemoryTableWithV1Fallback = { val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties) InMemoryV1Provider.tables.put(name, t) + tables.put(Identifier.of(Array("default"), name), t) t } } @@ -272,7 +322,7 @@ class InMemoryTableWithV1Fallback( override val partitioning: Array[Transform], override val properties: util.Map[String, String]) extends Table - with SupportsWrite { + with SupportsWrite with SupportsRead { partitioning.foreach { t => if (!t.isInstanceOf[IdentityTransform]) { @@ -281,6 +331,7 @@ class InMemoryTableWithV1Fallback( } override def capabilities: util.Set[TableCapability] = Set( + TableCapability.BATCH_READ, TableCapability.V1_BATCH_WRITE, TableCapability.OVERWRITE_BY_FILTER, TableCapability.TRUNCATE).asJava @@ -338,6 +389,30 @@ class InMemoryTableWithV1Fallback( } } } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = + new V1ReadFallbackScanBuilder(schema) + + private class V1ReadFallbackScanBuilder(schema: StructType) extends ScanBuilder { + override def build(): Scan = new V1ReadFallbackScan(schema) + } + + private class V1ReadFallbackScan(schema: StructType) extends V1Scan { + override def readSchema(): StructType = schema + override def toV1TableScan[T <: BaseRelation with TableScan](context: SQLContext): T = + new V1TableScan(context, schema).asInstanceOf[T] + } + + private class V1TableScan( + context: SQLContext, + requiredSchema: StructType) extends BaseRelation with TableScan { + override def sqlContext: SQLContext = context + override def schema: StructType = requiredSchema + override def buildScan(): RDD[Row] = { + val data = InMemoryV1Provider.getTableData(context.sparkSession, name).collect() + context.sparkContext.makeRDD(data) + } + } } /** A rule that fails if a query plan is analyzed twice. */