[SPARK-43390][SQL] DSv2 allows CTAS/RTAS to reserve schema nullability #41070

TableCatalog.java:

@@ -199,6 +199,14 @@ default Table createTable(
     return createTable(ident, CatalogV2Util.v2ColumnsToStructType(columns), partitions, properties);
   }
 
+  /**
+   * If true, mark all the fields of the query schema as nullable when executing
+   * CREATE/REPLACE TABLE ... AS SELECT ... and creating the table.
+   */
+  default boolean useNullableQuerySchema() {
+    return true;
+  }
+
   /**
    * Apply a set of {@link TableChange changes} to a table in the catalog.
    * <p>
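
Connector catalogs opt out of the default force-nullable behavior with a one-method override. A minimal sketch, where `MyConnectorCatalog` is a hypothetical stand-in for any existing DSv2 `TableCatalog` implementation (the `ReserveSchemaNullabilityCatalog` added to the test suite below does the same thing against `InMemoryCatalog`):

```scala
// Hypothetical sketch: MyConnectorCatalog stands in for an existing DSv2
// TableCatalog implementation; overriding the new hook is the whole change.
class NullabilityPreservingCatalog extends MyConnectorCatalog {
  // Returning false makes CTAS/RTAS hand createTable/stageCreate the query
  // schema as-is instead of relaxing every field to nullable first.
  override def useNullableQuerySchema(): Boolean = false
}
```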

WriteToDataSourceV2Exec.scala:

@@ -82,7 +82,8 @@ case class CreateTableAsSelectExec(
       throw QueryCompilationErrors.tableAlreadyExistsError(ident)
     }
     val table = catalog.createTable(
-      ident, getV2Columns(query.schema), partitioning.toArray, properties.asJava)
+      ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
+      partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident, query)
   }
 }

@@ -115,7 +116,8 @@ case class AtomicCreateTableAsSelectExec(
       throw QueryCompilationErrors.tableAlreadyExistsError(ident)
     }
     val stagedTable = catalog.stageCreate(
-      ident, getV2Columns(query.schema), partitioning.toArray, properties.asJava)
+      ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
+      partitioning.toArray, properties.asJava)
     writeToTable(catalog, stagedTable, writeOptions, ident, query)
   }
 }

@@ -160,7 +162,8 @@ case class ReplaceTableAsSelectExec(
       throw QueryCompilationErrors.cannotReplaceMissingTableError(ident)
     }
     val table = catalog.createTable(
-      ident, getV2Columns(query.schema), partitioning.toArray, properties.asJava)
+      ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
+      partitioning.toArray, properties.asJava)
     writeToTable(catalog, table, writeOptions, ident, query)
   }
 }

@@ -191,7 +194,7 @@ case class AtomicReplaceTableAsSelectExec(
   val properties = CatalogV2Util.convertTableProperties(tableSpec)
 
   override protected def run(): Seq[InternalRow] = {
-    val columns = getV2Columns(query.schema)
+    val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema)
     if (catalog.tableExists(ident)) {
       val table = catalog.loadTable(ident)
       invalidateCache(catalog, table, ident)

@@ -555,9 +558,10 @@ case class DeltaWithMetadataWritingSparkTask(
 private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
   override def output: Seq[Attribute] = Nil
 
-  protected def getV2Columns(schema: StructType): Array[Column] = {
-    CatalogV2Util.structTypeToV2Columns(CharVarcharUtils.getRawSchema(
-      removeInternalMetadata(schema), conf).asNullable)
+  protected def getV2Columns(schema: StructType, forceNullable: Boolean): Array[Column] = {
+    val rawSchema = CharVarcharUtils.getRawSchema(removeInternalMetadata(schema), conf)
+    val tableSchema = if (forceNullable) rawSchema.asNullable else rawSchema
+    CatalogV2Util.structTypeToV2Columns(tableSchema)
   }
 
   protected def writeToTable(
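
The helper above is the single point where the nullability decision is applied. A self-contained sketch of the same branch using only the public `StructType` API (the real code also strips internal metadata and restores raw char/varchar types, and Spark's internal `asNullable` recurses into nested types; this sketch relaxes top-level fields only):

```scala
import org.apache.spark.sql.types.StructType

// Mirror of the forceNullable branch in getV2Columns, top-level fields only.
def tableSchemaFor(querySchema: StructType, forceNullable: Boolean): StructType =
  if (forceNullable) StructType(querySchema.fields.map(_.copy(nullable = true)))
  else querySchema

val query = new StructType().add("i", "int", nullable = false)
tableSchemaFor(query, forceNullable = true)("i").nullable   // true: default CTAS behavior
tableSchemaFor(query, forceNullable = false)("i").nullable  // false: nullability preserved
```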

DataSourceV2SQLSuite.scala:

@@ -781,28 +781,45 @@ class DataSourceV2SQLSuiteV1Filter
   }
 
   test("CreateTableAsSelect: nullable schema") {
+    registerCatalog("testcat_nullability", classOf[ReserveSchemaNullabilityCatalog])
+
     val basicCatalog = catalog("testcat").asTableCatalog
     val atomicCatalog = catalog("testcat_atomic").asTableCatalog
+    val reserveNullabilityCatalog = catalog("testcat_nullability").asTableCatalog
     val basicIdentifier = "testcat.table_name"
     val atomicIdentifier = "testcat_atomic.table_name"
+    val reserveNullabilityIdentifier = "testcat_nullability.table_name"
 
-    Seq((basicCatalog, basicIdentifier), (atomicCatalog, atomicIdentifier)).foreach {
-      case (catalog, identifier) =>
+    Seq(
+      (basicCatalog, basicIdentifier, true),
+      (atomicCatalog, atomicIdentifier, true),
+      (reserveNullabilityCatalog, reserveNullabilityIdentifier, false)).foreach {
+      case (catalog, identifier, nullable) =>
         spark.sql(s"CREATE TABLE $identifier USING foo AS SELECT 1 i")
 
         val table = catalog.loadTable(Identifier.of(Array(), "table_name"))
 
         assert(table.name == identifier)
         assert(table.partitioning.isEmpty)
         assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
-        assert(table.schema == new StructType().add("i", "int"))
+        assert(table.schema == new StructType().add("i", "int", nullable))
 
         val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
         checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Row(1))
 
-        sql(s"INSERT INTO $identifier SELECT CAST(null AS INT)")
-        val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
-        checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq(Row(1), Row(null)))
+        def insertNullValueAndCheck(): Unit = {
+          sql(s"INSERT INTO $identifier SELECT CAST(null AS INT)")
+          val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+          checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq(Row(1), Row(null)))
+        }
+        if (nullable) {
+          insertNullValueAndCheck()
+        } else {
+          val e = intercept[Exception] {
+            insertNullValueAndCheck()
+          }
+          assert(e.getMessage.contains("Null value appeared in non-nullable field"))
+        }
     }
   }
 
@@ -2311,7 +2328,7 @@ class DataSourceV2SQLSuiteV1Filter
 
   test("global temp view should not be masked by v2 catalog") {
     val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
-    spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName)
+    registerCatalog(globalTempDB, classOf[InMemoryTableCatalog])
 
     try {
       sql("create global temp view v as select 1")

@@ -2336,7 +2353,7 @@ class DataSourceV2SQLSuiteV1Filter
 
   test("SPARK-30104: v2 catalog named global_temp will be masked") {
     val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
-    spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName)
+    registerCatalog(globalTempDB, classOf[InMemoryTableCatalog])
     checkError(
       exception = intercept[AnalysisException] {
         // Since the following multi-part name starts with `globalTempDB`, it is resolved to

@@ -2543,7 +2560,7 @@ class DataSourceV2SQLSuiteV1Filter
       context = ExpectedContext(fragment = "testcat.abc", start = 17, stop = 27))
 
     val globalTempDB = spark.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
-    spark.conf.set(s"spark.sql.catalog.$globalTempDB", classOf[InMemoryTableCatalog].getName)
+    registerCatalog(globalTempDB, classOf[InMemoryTableCatalog])
     withTempView("v") {
       sql("create global temp view v as select 1")
       checkError(

@@ -3261,3 +3278,7 @@ class FakeV2Provider extends SimpleTableProvider {
     throw new UnsupportedOperationException("Unnecessary for DDL tests")
   }
 }
+
+class ReserveSchemaNullabilityCatalog extends InMemoryCatalog {
+  override def useNullableQuerySchema(): Boolean = false
+}
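
End to end, this test catalog flips the observable behavior that the modified test asserts. A hypothetical `spark-shell` style session (`ReserveSchemaNullabilityCatalog` and `InMemoryCatalog` are test-only classes, so this is illustrative rather than runnable against a stock distribution):

```scala
// Register the nullability-preserving catalog under the suite's test name.
spark.conf.set(
  "spark.sql.catalog.testcat_nullability",
  classOf[ReserveSchemaNullabilityCatalog].getName)

// `i` comes from a non-null literal, and the catalog now keeps it NOT NULL.
spark.sql("CREATE TABLE testcat_nullability.t USING foo AS SELECT 1 AS i")

// Under the default (useNullableQuerySchema = true) this would store a NULL;
// here it fails with "Null value appeared in non-nullable field".
spark.sql("INSERT INTO testcat_nullability.t SELECT CAST(null AS INT)")
```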

DatasourceV2SQLBase.scala:

@@ -21,26 +21,27 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.connector.catalog.{CatalogPlugin, InMemoryCatalog, InMemoryPartitionTableCatalog, InMemoryTableWithV2FilterCatalog, StagingInMemoryTableCatalog}
-import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.test.SharedSparkSession
 
 trait DatasourceV2SQLBase
   extends QueryTest with SharedSparkSession with BeforeAndAfter {
 
+  protected def registerCatalog[T <: CatalogPlugin](name: String, clazz: Class[T]): Unit = {
+    spark.conf.set(s"spark.sql.catalog.$name", clazz.getName)
+  }
+
   protected def catalog(name: String): CatalogPlugin = {
     spark.sessionState.catalogManager.catalog(name)
  }
 
   before {
-    spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryCatalog].getName)
-    spark.conf.set("spark.sql.catalog.testv2filter",
-      classOf[InMemoryTableWithV2FilterCatalog].getName)
-    spark.conf.set("spark.sql.catalog.testpart", classOf[InMemoryPartitionTableCatalog].getName)
-    spark.conf.set(
-      "spark.sql.catalog.testcat_atomic", classOf[StagingInMemoryTableCatalog].getName)
-    spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryCatalog].getName)
-    spark.conf.set(
-      V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName)
+    registerCatalog("testcat", classOf[InMemoryCatalog])
+    registerCatalog("testv2filter", classOf[InMemoryTableWithV2FilterCatalog])
+    registerCatalog("testpart", classOf[InMemoryPartitionTableCatalog])
+    registerCatalog("testcat_atomic", classOf[StagingInMemoryTableCatalog])
+    registerCatalog("testcat2", classOf[InMemoryCatalog])
+    registerCatalog(SESSION_CATALOG_NAME, classOf[InMemoryTableSessionCatalog])
 
     val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
     df.createOrReplaceTempView("source")
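
Suites built on `DatasourceV2SQLBase` can use the same helper for ad-hoc catalogs; a hypothetical snippet inside a test body, mirroring the registrations above:

```scala
// Hypothetical: expose a staging-capable catalog as "mycat_atomic" for one test.
registerCatalog("mycat_atomic", classOf[StagingInMemoryTableCatalog])
sql("CREATE TABLE mycat_atomic.t USING foo AS SELECT 1 AS i")
```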