From d0cd56d768dff336913f1a71d71c5dde91a23308 Mon Sep 17 00:00:00 2001
From: Cheng Pan
Date: Sat, 6 May 2023 05:35:03 +0800
Subject: [PATCH 1/3] [SPARK-43390][SQL] DSv2 allows CTAS/RTAS to reserve schema nullability

---
 .../sql/connector/catalog/TableCatalog.java   |  8 ++++
 .../v2/WriteToDataSourceV2Exec.scala          | 18 +++++----
 .../sql/connector/DataSourceV2SQLSuite.scala  | 39 ++++++++++++++-----
 .../sql/connector/DatasourceV2SQLBase.scala   | 21 +++++-----
 4 files changed, 60 insertions(+), 26 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
index eb442ad38bde5..14f6d26a9c7d5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
@@ -199,6 +199,14 @@ default Table createTable(
     return createTable(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
   }
 
+  /**
+   * Return whether to reserve schema nullability of query output or forcibly use nullable schema
+   * on creating table implicitly, e.g. CTAS/RTAS.
+   */
+  default boolean createTableReserveSchemaNullability() {
+    return false;
+  }
+
   /**
    * Apply a set of {@link TableChange changes} to a table in the catalog.
    *
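
A minimal sketch of how a connector could opt in to the new hook (assuming the
test-only InMemoryCatalog as the base class; the class name is illustrative,
not part of the patch):

    import org.apache.spark.sql.connector.catalog.InMemoryCatalog

    // Return true to ask Spark to keep the nullability of the CTAS/RTAS query
    // output; the default (false) keeps the old behavior of widening every
    // column to nullable.
    class NullabilityPreservingCatalog extends InMemoryCatalog {
      override def createTableReserveSchemaNullability(): Boolean = true
    }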

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 426f33129a6ac..f44fd0cdbac06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -82,7 +82,8 @@ case class CreateTableAsSelectExec(
       throw QueryCompilationErrors.tableAlreadyExistsError(ident)
     }
     val table = catalog.createTable(
-      ident, getV2Columns(query.schema), partitioning.toArray, properties.asJava)
+      ident, getV2Columns(query.schema, catalog.createTableReserveSchemaNullability),
+      partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident, query)
   }
 }
@@ -115,7 +116,8 @@ case class AtomicCreateTableAsSelectExec(
       throw QueryCompilationErrors.tableAlreadyExistsError(ident)
     }
     val stagedTable = catalog.stageCreate(
-      ident, getV2Columns(query.schema), partitioning.toArray, properties.asJava)
+      ident, getV2Columns(query.schema, catalog.createTableReserveSchemaNullability),
+      partitioning.toArray, properties.asJava)
     writeToTable(catalog, stagedTable, writeOptions, ident, query)
   }
 }
@@ -160,7 +162,8 @@ case class ReplaceTableAsSelectExec(
       throw QueryCompilationErrors.cannotReplaceMissingTableError(ident)
     }
     val table = catalog.createTable(
-      ident, getV2Columns(query.schema), partitioning.toArray, properties.asJava)
+      ident, getV2Columns(query.schema, catalog.createTableReserveSchemaNullability),
+      partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident, query)
   }
 }
@@ -191,7 +194,7 @@ case class AtomicReplaceTableAsSelectExec(
   val properties = CatalogV2Util.convertTableProperties(tableSpec)
 
   override protected def run(): Seq[InternalRow] = {
-    val columns = getV2Columns(query.schema)
+    val columns = getV2Columns(query.schema, catalog.createTableReserveSchemaNullability)
     if (catalog.tableExists(ident)) {
       val table = catalog.loadTable(ident)
       invalidateCache(catalog, table, ident)
@@ -555,9 +558,10 @@ case class DeltaWithMetadataWritingSparkTask(
 private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
   override def output: Seq[Attribute] = Nil
 
-  protected def getV2Columns(schema: StructType): Array[Column] = {
-    CatalogV2Util.structTypeToV2Columns(CharVarcharUtils.getRawSchema(
-      removeInternalMetadata(schema), conf).asNullable)
+  protected def getV2Columns(schema: StructType, reserveNullability: Boolean): Array[Column] = {
+    val rawSchema = CharVarcharUtils.getRawSchema(removeInternalMetadata(schema), conf)
+    val tableSchema = if (reserveNullability) rawSchema else rawSchema.asNullable
+    CatalogV2Util.structTypeToV2Columns(tableSchema)
   }
 
   protected def writeToTable(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 053afb84c103a..bb844c852471c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -781,13 +781,20 @@ class DataSourceV2SQLSuiteV1Filter
   }
 
   test("CreateTableAsSelect: nullable schema") {
+    registerCatalog("testcat_nullability", classOf[TestCreateTableReserveSchemaNullabilityCatalog])
+
     val basicCatalog = catalog("testcat").asTableCatalog
     val atomicCatalog = catalog("testcat_atomic").asTableCatalog
+    val reserveNullabilityCatalog = catalog("testcat_nullability").asTableCatalog
     val basicIdentifier = "testcat.table_name"
     val atomicIdentifier = "testcat_atomic.table_name"
+    val reserveNullabilityIdentifier = "testcat_nullability.table_name"
 
-    Seq((basicCatalog, basicIdentifier), (atomicCatalog, atomicIdentifier)).foreach {
-      case (catalog, identifier) =>
+    Seq(
+      (basicCatalog, basicIdentifier, true),
+      (atomicCatalog, atomicIdentifier, true),
+      (reserveNullabilityCatalog, reserveNullabilityIdentifier, false)).foreach {
+      case (catalog, identifier, nullable) =>
         spark.sql(s"CREATE TABLE $identifier USING foo AS SELECT 1 i")
 
         val table = catalog.loadTable(Identifier.of(Array(), "table_name"))
@@ -795,14 +802,24 @@ class DataSourceV2SQLSuiteV1Filter
         assert(table.name == identifier)
         assert(table.partitioning.isEmpty)
         assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
-        assert(table.schema == new StructType().add("i", "int"))
+        assert(table.schema == new StructType().add("i", "int", nullable))
 
         val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
         checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Row(1))
 
-        sql(s"INSERT INTO $identifier SELECT CAST(null AS INT)")
-        val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
-        checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq(Row(1), Row(null)))
+        def insertNullValueAndCheck(): Unit = {
+          sql(s"INSERT INTO $identifier SELECT CAST(null AS INT)")
+          val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+          checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq(Row(1), Row(null)))
+        }
+        if (nullable) {
+          insertNullValueAndCheck()
+        } else {
+          val e = intercept[Exception] {
+            insertNullValueAndCheck()
+          }
+          assert(e.getMessage.contains("Null value appeared in non-nullable field"))
+        }
     }
   }
 
@@ -2311,7 +2328,7 @@ class DataSourceV2SQLSuiteV1Filter
 
   test("global temp view should not be masked by v2 catalog") {
     val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
-    spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName)
+    registerCatalog(globalTempDB, classOf[InMemoryTableCatalog])
 
     try {
       sql("create global temp view v as select 1")
@@ -2336,7 +2353,7 @@ class DataSourceV2SQLSuiteV1Filter
 
   test("SPARK-30104: v2 catalog named global_temp will be masked") {
     val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
-    spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName)
+    registerCatalog(globalTempDB, classOf[InMemoryTableCatalog])
     checkError(
       exception = intercept[AnalysisException] {
         // Since the following multi-part name starts with `globalTempDB`, it is resolved to
@@ -2543,7 +2560,7 @@ class DataSourceV2SQLSuiteV1Filter
       context = ExpectedContext(fragment = "testcat.abc", start = 17, stop = 27))
 
     val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
-    spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName)
+    registerCatalog(globalTempDB, classOf[InMemoryTableCatalog])
     withTempView("v") {
       sql("create global temp view v as select 1")
       checkError(
@@ -3261,3 +3278,7 @@ class FakeV2Provider extends SimpleTableProvider {
     throw new UnsupportedOperationException("Unnecessary for DDL tests")
   }
 }
+
+class TestCreateTableReserveSchemaNullabilityCatalog extends InMemoryCatalog {
+  override def createTableReserveSchemaNullability(): Boolean = true
+}
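
To make the behavior under test concrete, a sketch of what the suite exercises
(catalog and table names as registered above; "foo" is the suite's test
provider):

    // Default catalogs: the CTAS schema is forced nullable, so a NULL can be
    // inserted later.
    spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT 1 i")
    spark.sql("INSERT INTO testcat.table_name SELECT CAST(null AS INT)")  // succeeds

    // Nullability-reserving catalog: `SELECT 1 i` never produces NULLs, so the
    // column is created NOT NULL and the same INSERT fails at write time with
    // "Null value appeared in non-nullable field".
    spark.sql("CREATE TABLE testcat_nullability.table_name USING foo AS SELECT 1 i")
    spark.sql("INSERT INTO testcat_nullability.table_name SELECT CAST(null AS INT)")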
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
index 9c30851502518..4ccff44fa0674 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DatasourceV2SQLBase.scala
@@ -21,26 +21,27 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryCatalog, InMemoryPartitionTableCatalog, InMemoryTableWithV2FilterCatalog, StagingInMemoryTableCatalog}
-import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.test.SharedSparkSession
 
 trait DatasourceV2SQLBase
   extends QueryTest with SharedSparkSession with BeforeAndAfter {
 
+  protected def registerCatalog[T <: CatalogPlugin](name: String, clazz: Class[T]): Unit = {
+    spark.conf.set(s"spark.sql.catalog.$name", clazz.getName)
+  }
+
   protected def catalog(name: String): CatalogPlugin = {
     spark.sessionState.catalogManager.catalog(name)
   }
 
   before {
-    spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
-    spark.conf.set("spark.sql.catalog.testv2filter",
-      classOf[InMemoryTableWithV2FilterCatalog].getName)
-    spark.conf.set("spark.sql.catalog.testpart", classOf[InMemoryPartitionTableCatalog].getName)
-    spark.conf.set(
-      "spark.sql.catalog.testcat_atomic", classOf[StagingInMemoryTableCatalog].getName)
-    spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryCatalog].getName)
-    spark.conf.set(
-      V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName)
+    registerCatalog("testcat", classOf[InMemoryCatalog])
+    registerCatalog("testv2filter", classOf[InMemoryTableWithV2FilterCatalog])
+    registerCatalog("testpart", classOf[InMemoryPartitionTableCatalog])
+    registerCatalog("testcat_atomic", classOf[StagingInMemoryTableCatalog])
+    registerCatalog("testcat2", classOf[InMemoryCatalog])
+    registerCatalog(SESSION_CATALOG_NAME, classOf[InMemoryTableSessionCatalog])
 
     val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
     df.createOrReplaceTempView("source")

From 080c0a1235937b59fbc48346cef986527dc2a103 Mon Sep 17 00:00:00 2001
From: Cheng Pan
Date: Mon, 8 May 2023 13:02:05 +0800
Subject: [PATCH 2/3] rename to useNullableQuerySchema

---
 .../spark/sql/connector/catalog/TableCatalog.java  |  8 ++++----
 .../datasources/v2/WriteToDataSourceV2Exec.scala   | 12 ++++++------
 .../spark/sql/connector/DataSourceV2SQLSuite.scala |  2 +-
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
index 14f6d26a9c7d5..6cfd5ab1b6bee 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
@@ -200,11 +200,11 @@ default Table createTable(
   }
 
   /**
-   * Return whether to reserve schema nullability of query output or forcibly use nullable schema
-   * on creating table implicitly, e.g. CTAS/RTAS.
+   * If true, mark all the fields of the query schema as nullable when executing
+   * CREATE/REPLACE TABLE ... AS SELECT ... and creating the table.
    */
-  default boolean createTableReserveSchemaNullability() {
-    return false;
+  default boolean useNullableQuerySchema() {
+    return true;
   }
 
   /**
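
With the rename the polarity flips: useNullableQuerySchema() returning true
(the default) keeps the force-nullable behavior, and a catalog now preserves
the query schema's nullability by returning false. A minimal sketch mirroring
the test catalog updated below (the class name is illustrative):

    import org.apache.spark.sql.connector.catalog.InMemoryCatalog

    // Returning false keeps the query schema's nullability on CTAS/RTAS;
    // the default (true) marks every field nullable.
    class NullabilityPreservingCatalog extends InMemoryCatalog {
      override def useNullableQuerySchema(): Boolean = false
    }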
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index f44fd0cdbac06..4a9b85450a176 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -82,7 +82,7 @@ case class CreateTableAsSelectExec(
       throw QueryCompilationErrors.tableAlreadyExistsError(ident)
     }
     val table = catalog.createTable(
-      ident, getV2Columns(query.schema, catalog.createTableReserveSchemaNullability),
+      ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
       partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident, query)
   }
@@ -116,7 +116,7 @@ case class AtomicCreateTableAsSelectExec(
       throw QueryCompilationErrors.tableAlreadyExistsError(ident)
     }
     val stagedTable = catalog.stageCreate(
-      ident, getV2Columns(query.schema, catalog.createTableReserveSchemaNullability),
+      ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
       partitioning.toArray, properties.asJava)
     writeToTable(catalog, stagedTable, writeOptions, ident, query)
   }
@@ -162,7 +162,7 @@ case class ReplaceTableAsSelectExec(
       throw QueryCompilationErrors.cannotReplaceMissingTableError(ident)
     }
     val table = catalog.createTable(
-      ident, getV2Columns(query.schema, catalog.createTableReserveSchemaNullability),
+      ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
       partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident, query)
   }
@@ -194,7 +194,7 @@ case class AtomicReplaceTableAsSelectExec(
   val properties = CatalogV2Util.convertTableProperties(tableSpec)
 
   override protected def run(): Seq[InternalRow] = {
-    val columns = getV2Columns(query.schema, catalog.createTableReserveSchemaNullability)
+    val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema)
     if (catalog.tableExists(ident)) {
       val table = catalog.loadTable(ident)
       invalidateCache(catalog, table, ident)
@@ -558,9 +558,9 @@ case class DeltaWithMetadataWritingSparkTask(
 private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
   override def output: Seq[Attribute] = Nil
 
-  protected def getV2Columns(schema: StructType, reserveNullability: Boolean): Array[Column] = {
+  protected def getV2Columns(schema: StructType, forceNullable: Boolean): Array[Column] = {
     val rawSchema = CharVarcharUtils.getRawSchema(removeInternalMetadata(schema), conf)
-    val tableSchema = if (reserveNullability) rawSchema else rawSchema.asNullable
+    val tableSchema = if (forceNullable) rawSchema.asNullable else rawSchema
     CatalogV2Util.structTypeToV2Columns(tableSchema)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index bb844c852471c..d136b68aee66c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -3280,5 +3280,5 @@ class FakeV2Provider extends SimpleTableProvider {
 }
 
 class TestCreateTableReserveSchemaNullabilityCatalog extends InMemoryCatalog {
-  override def createTableReserveSchemaNullability(): Boolean = true
+  override def useNullableQuerySchema(): Boolean = false
 }

From 13e8b75fcca41928bb1ff8cee603f8841433e467 Mon Sep 17 00:00:00 2001
From: Cheng Pan
Date: Mon, 8 May 2023 17:37:49 +0800
Subject: [PATCH 3/3] nit

---
 .../org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index d136b68aee66c..6f14b0971caa5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -781,7 +781,7 @@ class DataSourceV2SQLSuiteV1Filter
   }
 
   test("CreateTableAsSelect: nullable schema") {
-    registerCatalog("testcat_nullability", classOf[TestCreateTableReserveSchemaNullabilityCatalog])
+    registerCatalog("testcat_nullability", classOf[ReserveSchemaNullabilityCatalog])
 
     val basicCatalog = catalog("testcat").asTableCatalog
     val atomicCatalog = catalog("testcat_atomic").asTableCatalog
@@ -3279,6 +3279,6 @@ class FakeV2Provider extends SimpleTableProvider {
   }
 }
 
-class TestCreateTableReserveSchemaNullabilityCatalog extends InMemoryCatalog {
+class ReserveSchemaNullabilityCatalog extends InMemoryCatalog {
   override def useNullableQuerySchema(): Boolean = false
 }
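
The heart of the change is the schema decision in getV2Columns. A
self-contained sketch of the same logic on a plain StructType (illustrative
only: asNullable is spark-internal, so the widening is spelled out here, and
only for top-level fields):

    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    // Stand-in for the decision in V2CreateTableAsSelectBaseExec.getV2Columns:
    // either keep the query schema as-is or mark every field nullable.
    def tableSchemaFor(querySchema: StructType, forceNullable: Boolean): StructType =
      if (forceNullable) StructType(querySchema.map(_.copy(nullable = true)))
      else querySchema

    val query = StructType(Seq(StructField("i", IntegerType, nullable = false)))
    assert(tableSchemaFor(query, forceNullable = true).head.nullable)    // default path
    assert(!tableSchemaFor(query, forceNullable = false).head.nullable)  // preserving path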