[SPARK-34897][SQL] Support reconcile schemas based on index after nested column pruning

### What changes were proposed in this pull request?

`SchemaPruning.pruneDataSchema` will remove top-level `StructField`s when [pruning nested columns](https://github.com/apache/spark/blob/0f2c0b53e8fb18c86c67b5dd679c006db93f94a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala#L28-L42). For example:
```scala
spark.sql(
  """
    |CREATE TABLE t1 (
    |  _col0 INT,
    |  _col1 STRING,
    |  _col2 STRUCT<c1: STRING, c2: STRING, c3: STRING, c4: BIGINT>)
    |USING ORC
    |""".stripMargin)

spark.sql("INSERT INTO t1 values(1, '2', struct('a', 'b', 'c', 10L))")

spark.sql("SELECT _col0, _col2.c1 FROM t1").show
```

Before this PR, the returned schema is ``` `_col0` INT,`_col2` STRUCT<`c1`: STRING> ``` and it will throw an exception:
```
java.lang.AssertionError: assertion failed: The given data schema struct<_col0:int,_col2:struct<c1:string>> has less fields than the actual ORC physical schema, no idea which columns were dropped, fail to read.
	at scala.Predef$.assert(Predef.scala:223)
	at org.apache.spark.sql.execution.datasources.orc.OrcUtils$.requestedColumnIds(OrcUtils.scala:160)
```
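
The assertion comes from `OrcUtils.requestedColumnIds`. Here is a rough sketch of the failing check (simplified signature, not the real method, which handles several more cases): when the ORC file carries placeholder physical column names, catalog columns can only be matched to file columns by position, so the data schema must retain every top-level column:

```scala
// Simplified sketch of the check in OrcUtils.requestedColumnIds (not the real code):
// with placeholder physical names, catalog column i can only mean physical column i,
// which is ill-defined once top-level columns have been dropped.
def requestedColumnIdsSketch(dataFields: Seq[String], physicalFields: Seq[String]): Seq[Int] = {
  assert(physicalFields.length <= dataFields.length,
    "The given data schema has less fields than the actual ORC physical schema, " +
      "no idea which columns were dropped, fail to read.")
  dataFields.indices // catalog column i <-> physical column i
}
```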

After this PR, the returned schema is ``` `_col0` INT,`_col1` STRING,`_col2` STRUCT<`c1`: STRING> ```.

The final schema is ``` `_col0` INT,`_col2` STRUCT<`c1`: STRING> ``` after the complete top-level column pruning:
https://github.com/apache/spark/blob/7a5647a93aaea9d1d78d9262e24fc8c010db04d0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala#L208-L213

https://github.com/apache/spark/blob/e64eb75aede71a5403a4d4436e63b1fcfdeca14d/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala#L96-L97
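
In other words, nested pruning now only trims fields inside structs; whole top-level columns are still removed afterwards by the per-path logic linked above. A minimal sketch of that second step (`pruneTopLevel` is a hypothetical helper name, mirroring the linked `PushDownUtils` change):

```scala
import org.apache.spark.sql.types.StructType

// Keep only the top-level columns the query references; nested pruning has already
// trimmed the struct fields inside them.
def pruneTopLevel(prunedSchema: StructType, neededFieldNames: Set[String]): StructType =
  StructType(prunedSchema.filter(f => neededFieldNames.contains(f.name)))

val schema = StructType.fromDDL("_col0 INT, _col1 STRING, _col2 STRUCT<c1: STRING>")
pruneTopLevel(schema, Set("_col0", "_col2"))
// => `_col0` INT, `_col2` STRUCT<`c1`: STRING>
```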

### Why are the changes needed?

Fix a bug: nested column pruning could drop top-level columns from the data schema, which breaks the ORC data source's index-based (by-ordinal) schema reconciliation.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test.

Closes #31993 from wangyum/SPARK-34897.

Authored-by: Yuming Wang <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
wangyum authored and viirya committed Apr 21, 2021
1 parent 81dbaed commit e609395
Showing 5 changed files with 108 additions and 57 deletions.
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -22,9 +22,12 @@ import org.apache.spark.sql.types._

 object SchemaPruning extends SQLConfHelper {
   /**
-   * Filters the schema by the requested fields. For example, if the schema is struct<a:int, b:int>,
-   * and given requested field are "a", the field "b" is pruned in the returned schema.
-   * Note that schema field ordering at original schema is still preserved in pruned schema.
+   * Prunes the nested schema by the requested fields. For example, if the schema is:
+   * `id int, s struct<a:int, b:int>`, and given requested field "s.a", the inner field "b"
+   * is pruned in the returned schema: `id int, s struct<a:int>`.
+   * Note that:
+   *   1. The schema field ordering at original schema is still preserved in pruned schema.
+   *   2. The top-level fields are not pruned here.
    */
   def pruneDataSchema(
       dataSchema: StructType,
@@ -34,11 +37,10 @@ object SchemaPruning extends SQLConfHelper {
     // in the resulting schema may differ from their ordering in the logical relation's
     // original schema
     val mergedSchema = requestedRootFields
-      .map { case root: RootField => StructType(Array(root.field)) }
+      .map { root: RootField => StructType(Array(root.field)) }
       .reduceLeft(_ merge _)
-    val dataSchemaFieldNames = dataSchema.fieldNames.toSet
     val mergedDataSchema =
-      StructType(mergedSchema.filter(f => dataSchemaFieldNames.exists(resolver(_, f.name))))
+      StructType(dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
     // Sort the fields of mergedDataSchema according to their order in dataSchema,
     // recursively. This makes mergedDataSchema a pruned schema of dataSchema
     sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
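
To make the new one-liner concrete, here is a standalone sketch of what the replaced expression computes for the table in the description (the real code resolves names with `conf.resolver` instead of plain equality):

```scala
import org.apache.spark.sql.types._

val dataSchema = StructType.fromDDL(
  "_col0 INT, _col1 STRING, _col2 STRUCT<c1: STRING, c2: STRING, c3: STRING, c4: BIGINT>")
// Merged schema built from the requested root fields `_col0` and `_col2.c1`:
val mergedSchema = StructType.fromDDL("_col0 INT, _col2 STRUCT<c1: STRING>")

// For each top-level field of dataSchema, use its nested-pruned version when one was
// requested, otherwise keep the field as-is — so no top-level field is dropped.
val mergedDataSchema = StructType(
  dataSchema.map(d => mergedSchema.find(_.name == d.name).getOrElse(d)))
// => `_col0` INT, `_col1` STRING, `_col2` STRUCT<`c1`: STRING>
```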
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -18,39 +18,36 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
 import org.apache.spark.sql.types._
 
 class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
-
-  def getRootFields(requestedFields: StructField*): Seq[RootField] = {
-    requestedFields.map { f =>
+  private def testPrunedSchema(
+      schema: StructType,
+      requestedFields: Seq[StructField],
+      expectedSchema: StructType): Unit = {
+    val requestedRootFields = requestedFields.map { f =>
       // `derivedFromAtt` doesn't affect the result of pruned schema.
       SchemaPruning.RootField(field = f, derivedFromAtt = true)
     }
+    val prunedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
+    assert(prunedSchema === expectedSchema)
   }
 
   test("prune schema by the requested fields") {
-    def testPrunedSchema(
-        schema: StructType,
-        requestedFields: StructField*): Unit = {
-      val requestedRootFields = requestedFields.map { f =>
-        // `derivedFromAtt` doesn't affect the result of pruned schema.
-        SchemaPruning.RootField(field = f, derivedFromAtt = true)
-      }
-      val expectedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
-      assert(expectedSchema == StructType(requestedFields))
-    }
-
-    testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("a", IntegerType))
-    testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("b", IntegerType))
+    testPrunedSchema(
+      StructType.fromDDL("a int, b int"),
+      Seq(StructField("a", IntegerType)),
+      StructType.fromDDL("a int, b int"))
 
     val structOfStruct = StructType.fromDDL("a struct<a:int, b:int>, b int")
-    testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("a int, b int")))
-    testPrunedSchema(structOfStruct, StructField("b", IntegerType))
-    testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("b int")))
+    testPrunedSchema(structOfStruct,
+      Seq(StructField("a", StructType.fromDDL("a int")), StructField("b", IntegerType)),
+      StructType.fromDDL("a struct<a:int>, b int"))
+    testPrunedSchema(structOfStruct,
+      Seq(StructField("a", StructType.fromDDL("a int"))),
+      StructType.fromDDL("a struct<a:int>, b int"))
 
     val arrayOfStruct = StructField("a", ArrayType(StructType.fromDDL("a int, b int, c string")))
     val mapOfStruct = StructField("d", MapType(StructType.fromDDL("a int, b int, c string"),
@@ -60,44 +57,76 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
       arrayOfStruct :: StructField("b", structOfStruct) :: StructField("c", IntegerType) ::
       mapOfStruct :: Nil)

-    testPrunedSchema(complexStruct, StructField("a", ArrayType(StructType.fromDDL("b int"))),
-      StructField("b", StructType.fromDDL("a int")))
     testPrunedSchema(complexStruct,
-      StructField("a", ArrayType(StructType.fromDDL("b int, c string"))),
-      StructField("b", StructType.fromDDL("b int")))
+      Seq(StructField("a", ArrayType(StructType.fromDDL("b int"))),
+        StructField("b", StructType.fromDDL("a int"))),
+      StructType(
+        StructField("a", ArrayType(StructType.fromDDL("b int"))) ::
+          StructField("b", StructType.fromDDL("a int")) ::
+          StructField("c", IntegerType) ::
+          mapOfStruct :: Nil))
+    testPrunedSchema(complexStruct,
+      Seq(StructField("a", ArrayType(StructType.fromDDL("b int, c string"))),
+        StructField("b", StructType.fromDDL("b int"))),
+      StructType(
+        StructField("a", ArrayType(StructType.fromDDL("b int, c string"))) ::
+          StructField("b", StructType.fromDDL("b int")) ::
+          StructField("c", IntegerType) ::
+          mapOfStruct :: Nil))
 
     val selectFieldInMap = StructField("d", MapType(StructType.fromDDL("a int, b int"),
       StructType.fromDDL("e int, f string")))
-    testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
+    testPrunedSchema(complexStruct,
+      Seq(StructField("c", IntegerType), selectFieldInMap),
+      StructType(
+        arrayOfStruct ::
+          StructField("b", structOfStruct) ::
+          StructField("c", IntegerType) ::
+          selectFieldInMap :: Nil))
   }

   test("SPARK-35096: test case insensitivity of pruned schema") {
-    Seq(true, false).foreach(isCaseSensitive => {
+    val upperCaseSchema = StructType.fromDDL("A struct<A:int, B:int>, B int")
+    val lowerCaseSchema = StructType.fromDDL("a struct<a:int, b:int>, b int")
+    val upperCaseRequestedFields = Seq(StructField("A", StructType.fromDDL("A int")))
+    val lowerCaseRequestedFields = Seq(StructField("a", StructType.fromDDL("a int")))
+
+    Seq(true, false).foreach { isCaseSensitive =>
       withSQLConf(CASE_SENSITIVE.key -> isCaseSensitive.toString) {
         if (isCaseSensitive) {
-          // Schema is case-sensitive
-          val requestedFields = getRootFields(StructField("id", IntegerType))
-          val prunedSchema = SchemaPruning.pruneDataSchema(
-            StructType.fromDDL("ID int, name String"), requestedFields)
-          assert(prunedSchema == StructType(Seq.empty))
-          // Root fields are case-sensitive
-          val rootFieldsSchema = SchemaPruning.pruneDataSchema(
-            StructType.fromDDL("id int, name String"),
-            getRootFields(StructField("ID", IntegerType)))
-          assert(rootFieldsSchema == StructType(StructType(Seq.empty)))
+          testPrunedSchema(
+            upperCaseSchema,
+            upperCaseRequestedFields,
+            StructType.fromDDL("A struct<A:int>, B int"))
+          testPrunedSchema(
+            upperCaseSchema,
+            lowerCaseRequestedFields,
+            upperCaseSchema)
+
+          testPrunedSchema(
+            lowerCaseSchema,
+            upperCaseRequestedFields,
+            lowerCaseSchema)
+          testPrunedSchema(
+            lowerCaseSchema,
+            lowerCaseRequestedFields,
+            StructType.fromDDL("a struct<a:int>, b int"))
         } else {
-          // Schema is case-insensitive
-          val prunedSchema = SchemaPruning.pruneDataSchema(
-            StructType.fromDDL("ID int, name String"),
-            getRootFields(StructField("id", IntegerType)))
-          assert(prunedSchema == StructType(StructField("ID", IntegerType) :: Nil))
-          // Root fields are case-insensitive
-          val rootFieldsSchema = SchemaPruning.pruneDataSchema(
-            StructType.fromDDL("id int, name String"),
-            getRootFields(StructField("ID", IntegerType)))
-          assert(rootFieldsSchema == StructType(StructField("id", IntegerType) :: Nil))
+          Seq(upperCaseRequestedFields, lowerCaseRequestedFields).foreach { requestedFields =>
+            testPrunedSchema(
+              upperCaseSchema,
+              requestedFields,
+              StructType.fromDDL("A struct<A:int>, B int"))
+          }
+
+          Seq(upperCaseRequestedFields, lowerCaseRequestedFields).foreach { requestedFields =>
+            testPrunedSchema(
+              lowerCaseSchema,
+              requestedFields,
+              StructType.fromDDL("a struct<a:int>, b int"))
+          }
         }
       }
-    })
+    }
   }
 }
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala
@@ -40,6 +40,8 @@ import org.apache.spark.sql.types.{StructField, StructType}
 case class HadoopFsRelation(
     location: FileIndex,
     partitionSchema: StructType,
+    // The top-level columns in `dataSchema` should match the actual physical file schema, otherwise
+    // the ORC data source may not work with the by-ordinal mode.
    dataSchema: StructType,
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
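
The "by-ordinal mode" in this comment refers to reading ORC files whose physical field names are meaningless placeholders, so columns can only be matched by position. A toy illustration (assumed values, not the actual reader code):

```scala
// Catalog columns are matched to physical ORC columns purely by index.
val catalogFields = Seq("_col0", "_col1", "_col2")
val requestedIds = Seq("_col0", "_col2").map(catalogFields.indexOf) // Seq(0, 2)
// If nested pruning had dropped _col1 from the data schema, _col2 would map to
// index 1 and the reader would fetch the wrong physical column.
```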
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -81,6 +81,10 @@ object PushDownUtils extends PredicateHelper {
       relation: DataSourceV2Relation,
       projects: Seq[NamedExpression],
       filters: Seq[Expression]): (Scan, Seq[AttributeReference]) = {
+    val exprs = projects ++ filters
+    val requiredColumns = AttributeSet(exprs.flatMap(_.references))
+    val neededOutput = relation.output.filter(requiredColumns.contains)
+
     scanBuilder match {
       case r: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled =>
         val rootFields = SchemaPruning.identifyRootFields(projects, filters)
@@ -89,14 +93,12 @@ object PushDownUtils extends PredicateHelper {
         } else {
           new StructType()
         }
-        r.pruneColumns(prunedSchema)
+        val neededFieldNames = neededOutput.map(_.name).toSet
+        r.pruneColumns(StructType(prunedSchema.filter(f => neededFieldNames.contains(f.name))))
         val scan = r.build()
         scan -> toOutputAttrs(scan.readSchema(), relation)
 
       case r: SupportsPushDownRequiredColumns =>
-        val exprs = projects ++ filters
-        val requiredColumns = AttributeSet(exprs.flatMap(_.references))
-        val neededOutput = relation.output.filter(requiredColumns.contains)
         r.pruneColumns(neededOutput.toStructType)
         val scan = r.build()
         // always project, in case the relation's output has been updated and doesn't match
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
@@ -633,4 +633,20 @@ class OrcSourceSuite extends OrcSuite with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-34897: Support reconcile schemas based on index after nested column pruning") {
+    withTable("t1") {
+      spark.sql(
+        """
+          |CREATE TABLE t1 (
+          |  _col0 INT,
+          |  _col1 STRING,
+          |  _col2 STRUCT<c1: STRING, c2: STRING, c3: STRING, c4: BIGINT>)
+          |USING ORC
+          |""".stripMargin)
+
+      spark.sql("INSERT INTO t1 values(1, '2', struct('a', 'b', 'c', 10L))")
+      checkAnswer(spark.sql("SELECT _col0, _col2.c1 FROM t1"), Seq(Row(1, "a")))
+    }
+  }
 }
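
With the fix in place, the repro from the description returns the expected row instead of failing the ORC schema assertion — illustrative spark-shell output, matching the `checkAnswer` expectation above:

```scala
spark.sql("SELECT _col0, _col2.c1 FROM t1").show()
// +-----+---+
// |_col0| c1|
// +-----+---+
// |    1|  a|
// +-----+---+
```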
