[SPARK-34897][SQL] Support reconcile schemas based on index after nested column pruning #31993

Closed
wants to merge 9 commits
@@ -22,9 +22,12 @@ import org.apache.spark.sql.types._

object SchemaPruning extends SQLConfHelper {
/**
* Filters the schema by the requested fields. For example, if the schema is struct<a:int, b:int>,
* and given requested field are "a", the field "b" is pruned in the returned schema.
* Note that schema field ordering at original schema is still preserved in pruned schema.
* Prunes the nested schema by the requested fields. For example, if the schema is:
* `id int, s struct<a:int, b:int>`, and given requested field "s.a", the inner field "b"
* is pruned in the returned schema: `id int, s struct<a:int>`.
* Note that:
* 1. The schema field ordering at original schema is still preserved in pruned schema.
* 2. The top-level fields are not pruned here.
*/
def pruneDataSchema(
dataSchema: StructType,
@@ -34,11 +37,10 @@ object SchemaPruning extends SQLConfHelper {
// in the resulting schema may differ from their ordering in the logical relation's
// original schema
val mergedSchema = requestedRootFields
.map { case root: RootField => StructType(Array(root.field)) }
.map { root: RootField => StructType(Array(root.field)) }
.reduceLeft(_ merge _)
val dataSchemaFieldNames = dataSchema.fieldNames.toSet
val mergedDataSchema =
StructType(mergedSchema.filter(f => dataSchemaFieldNames.exists(resolver(_, f.name))))
StructType(dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
// Sort the fields of mergedDataSchema according to their order in dataSchema,
// recursively. This makes mergedDataSchema a pruned schema of dataSchema
sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
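
A minimal sketch of the new pruneDataSchema behavior, using the same example schema as the doc comment above (plain Scala; the requested field shown here is illustrative, not taken from this diff):

  import org.apache.spark.sql.catalyst.expressions.SchemaPruning
  import org.apache.spark.sql.types._

  // Original schema: `id int, s struct<a:int, b:int>`
  val dataSchema = StructType.fromDDL("id int, s struct<a:int, b:int>")

  // Request only the nested field `s.a`.
  val requestedRootFields = Seq(
    SchemaPruning.RootField(
      field = StructField("s", StructType.fromDDL("a int")),
      derivedFromAtt = true))

  // Only the nested field `b` is pruned; the unread top-level column `id`
  // is kept, so positional (by-ordinal) reconciliation still lines up.
  val pruned = SchemaPruning.pruneDataSchema(dataSchema, requestedRootFields)
  // pruned: `id int, s struct<a:int>`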
@@ -18,39 +18,36 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
import org.apache.spark.sql.types._

class SchemaPruningSuite extends SparkFunSuite with SQLHelper {

def getRootFields(requestedFields: StructField*): Seq[RootField] = {
requestedFields.map { f =>
private def testPrunedSchema(
schema: StructType,
requestedFields: Seq[StructField],
expectedSchema: StructType): Unit = {
val requestedRootFields = requestedFields.map { f =>
// `derivedFromAtt` doesn't affect the result of pruned schema.
SchemaPruning.RootField(field = f, derivedFromAtt = true)
}
val prunedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
assert(prunedSchema === expectedSchema)
}

test("prune schema by the requested fields") {
def testPrunedSchema(
schema: StructType,
requestedFields: StructField*): Unit = {
val requestedRootFields = requestedFields.map { f =>
// `derivedFromAtt` doesn't affect the result of pruned schema.
SchemaPruning.RootField(field = f, derivedFromAtt = true)
}
val expectedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
assert(expectedSchema == StructType(requestedFields))
}

testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("a", IntegerType))
testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("b", IntegerType))
testPrunedSchema(
StructType.fromDDL("a int, b int"),
Seq(StructField("a", IntegerType)),
StructType.fromDDL("a int, b int"))

val structOfStruct = StructType.fromDDL("a struct<a:int, b:int>, b int")
testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("a int, b int")))
testPrunedSchema(structOfStruct, StructField("b", IntegerType))
testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("b int")))
testPrunedSchema(structOfStruct,
Seq(StructField("a", StructType.fromDDL("a int")), StructField("b", IntegerType)),
StructType.fromDDL("a struct<a:int>, b int"))
testPrunedSchema(structOfStruct,
Seq(StructField("a", StructType.fromDDL("a int"))),
StructType.fromDDL("a struct<a:int>, b int"))

val arrayOfStruct = StructField("a", ArrayType(StructType.fromDDL("a int, b int, c string")))
val mapOfStruct = StructField("d", MapType(StructType.fromDDL("a int, b int, c string"),
@@ -60,44 +57,76 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
arrayOfStruct :: StructField("b", structOfStruct) :: StructField("c", IntegerType) ::
mapOfStruct :: Nil)

testPrunedSchema(complexStruct, StructField("a", ArrayType(StructType.fromDDL("b int"))),
StructField("b", StructType.fromDDL("a int")))
testPrunedSchema(complexStruct,
StructField("a", ArrayType(StructType.fromDDL("b int, c string"))),
StructField("b", StructType.fromDDL("b int")))
Seq(StructField("a", ArrayType(StructType.fromDDL("b int"))),
StructField("b", StructType.fromDDL("a int"))),
StructType(
StructField("a", ArrayType(StructType.fromDDL("b int"))) ::
StructField("b", StructType.fromDDL("a int")) ::
StructField("c", IntegerType) ::
mapOfStruct :: Nil))
testPrunedSchema(complexStruct,
Seq(StructField("a", ArrayType(StructType.fromDDL("b int, c string"))),
StructField("b", StructType.fromDDL("b int"))),
StructType(
StructField("a", ArrayType(StructType.fromDDL("b int, c string"))) ::
StructField("b", StructType.fromDDL("b int")) ::
StructField("c", IntegerType) ::
mapOfStruct :: Nil))

val selectFieldInMap = StructField("d", MapType(StructType.fromDDL("a int, b int"),
StructType.fromDDL("e int, f string")))
testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
testPrunedSchema(complexStruct,
Seq(StructField("c", IntegerType), selectFieldInMap),
StructType(
arrayOfStruct ::
StructField("b", structOfStruct) ::
StructField("c", IntegerType) ::
selectFieldInMap :: Nil))
}

test("SPARK-35096: test case insensitivity of pruned schema") {
Seq(true, false).foreach(isCaseSensitive => {
val upperCaseSchema = StructType.fromDDL("A struct<A:int, B:int>, B int")
val lowerCaseSchema = StructType.fromDDL("a struct<a:int, b:int>, b int")
val upperCaseRequestedFields = Seq(StructField("A", StructType.fromDDL("A int")))
val lowerCaseRequestedFields = Seq(StructField("a", StructType.fromDDL("a int")))

Seq(true, false).foreach { isCaseSensitive =>
withSQLConf(CASE_SENSITIVE.key -> isCaseSensitive.toString) {
if (isCaseSensitive) {
// Schema is case-sensitive
val requestedFields = getRootFields(StructField("id", IntegerType))
val prunedSchema = SchemaPruning.pruneDataSchema(
StructType.fromDDL("ID int, name String"), requestedFields)
assert(prunedSchema == StructType(Seq.empty))
// Root fields are case-sensitive
val rootFieldsSchema = SchemaPruning.pruneDataSchema(
StructType.fromDDL("id int, name String"),
getRootFields(StructField("ID", IntegerType)))
assert(rootFieldsSchema == StructType(StructType(Seq.empty)))
testPrunedSchema(
upperCaseSchema,
upperCaseRequestedFields,
StructType.fromDDL("A struct<A:int>, B int"))
testPrunedSchema(
upperCaseSchema,
lowerCaseRequestedFields,
upperCaseSchema)

testPrunedSchema(
lowerCaseSchema,
upperCaseRequestedFields,
lowerCaseSchema)
testPrunedSchema(
lowerCaseSchema,
lowerCaseRequestedFields,
StructType.fromDDL("a struct<a:int>, b int"))
} else {
// Schema is case-insensitive
val prunedSchema = SchemaPruning.pruneDataSchema(
StructType.fromDDL("ID int, name String"),
getRootFields(StructField("id", IntegerType)))
assert(prunedSchema == StructType(StructField("ID", IntegerType) :: Nil))
// Root fields are case-insensitive
val rootFieldsSchema = SchemaPruning.pruneDataSchema(
StructType.fromDDL("id int, name String"),
getRootFields(StructField("ID", IntegerType)))
assert(rootFieldsSchema == StructType(StructField("id", IntegerType) :: Nil))
Seq(upperCaseRequestedFields, lowerCaseRequestedFields).foreach { requestedFields =>
testPrunedSchema(
upperCaseSchema,
requestedFields,
StructType.fromDDL("A struct<A:int>, B int"))
}

Seq(upperCaseRequestedFields, lowerCaseRequestedFields).foreach { requestedFields =>
testPrunedSchema(
lowerCaseSchema,
requestedFields,
StructType.fromDDL("a struct<a:int>, b int"))
}
}
}
})
}
Comment on lines +89 to +130 — Contributor:
Tests LGTM, thanks for adding more scenarios.

}
}
@@ -40,6 +40,8 @@ import org.apache.spark.sql.types.{StructField, StructType}
case class HadoopFsRelation(
location: FileIndex,
partitionSchema: StructType,
// The top-level columns in `dataSchema` should match the actual physical file schema, otherwise
// the ORC data source may not work with the by-ordinal mode.
dataSchema: StructType,
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
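
A toy, self-contained illustration of why this comment matters (plain Scala, no Spark APIs; the `_col*` placeholder names mirror the Hive-written ORC scenario in SPARK-34897):

  // Physical ORC columns as written by Hive: positional placeholder names.
  val physicalColumns = Seq("_col0", "_col1", "_col2")

  // Hypothetical dataSchema if top-level pruning had dropped the unread `_col1`.
  val overPrunedColumns = Seq("_col0", "_col2")

  // By-ordinal reconciliation pairs columns purely by position, so the
  // requested `_col2` would silently read the file's `_col1` data.
  overPrunedColumns.zip(physicalColumns).foreach { case (requested, physical) =>
    println(s"requested $requested reads physical $physical")
  }
  // requested _col0 reads physical _col0
  // requested _col2 reads physical _col1   <- wrong column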
@@ -81,6 +81,10 @@ object PushDownUtils extends PredicateHelper {
relation: DataSourceV2Relation,
projects: Seq[NamedExpression],
filters: Seq[Expression]): (Scan, Seq[AttributeReference]) = {
val exprs = projects ++ filters
val requiredColumns = AttributeSet(exprs.flatMap(_.references))
val neededOutput = relation.output.filter(requiredColumns.contains)

scanBuilder match {
case r: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled =>
val rootFields = SchemaPruning.identifyRootFields(projects, filters)
@@ -89,14 +93,12 @@
} else {
new StructType()
}
r.pruneColumns(prunedSchema)
val neededFieldNames = neededOutput.map(_.name).toSet
r.pruneColumns(StructType(prunedSchema.filter(f => neededFieldNames.contains(f.name))))
Member Author:
Move the filter logic from SchemaPruning to PushDownUtils to support Data Source V2 column pruning.

val scan = r.build()
scan -> toOutputAttrs(scan.readSchema(), relation)

case r: SupportsPushDownRequiredColumns =>
val exprs = projects ++ filters
val requiredColumns = AttributeSet(exprs.flatMap(_.references))
val neededOutput = relation.output.filter(requiredColumns.contains)
r.pruneColumns(neededOutput.toStructType)
val scan = r.build()
// always project, in case the relation's output has been updated and doesn't match
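
A minimal sketch of the extra filtering step added here for the DSv2 path (column names are illustrative; in PushDownUtils the needed names come from `relation.output.filter(requiredColumns.contains)`):

  import org.apache.spark.sql.types._

  // Schema returned by SchemaPruning.pruneDataSchema: all top-level fields kept.
  val prunedSchema = StructType.fromDDL("id int, name string, s struct<a:int>")

  // Top-level columns actually referenced by the pushed-down projects/filters.
  val neededFieldNames = Set("id", "s")

  // Drop unreferenced top-level columns before calling pruneColumns, so the
  // data source still sees only the columns it needs.
  val requiredSchema = StructType(prunedSchema.filter(f => neededFieldNames.contains(f.name)))
  // requiredSchema: `id int, s struct<a:int>`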
@@ -633,4 +633,20 @@ class OrcSourceSuite extends OrcSuite with SharedSparkSession {
}
}
}

test("SPARK-34897: Support reconcile schemas based on index after nested column pruning") {
withTable("t1") {
spark.sql(
"""
|CREATE TABLE t1 (
| _col0 INT,
| _col1 STRING,
| _col2 STRUCT<c1: STRING, c2: STRING, c3: STRING, c4: BIGINT>)
|USING ORC
|""".stripMargin)

spark.sql("INSERT INTO t1 values(1, '2', struct('a', 'b', 'c', 10L))")
checkAnswer(spark.sql("SELECT _col0, _col2.c1 FROM t1"), Seq(Row(1, "a")))
}
}
}