[SPARK-25207][SQL] Case-insensitive field resolution for filter pushdown when reading Parquet

## What changes were proposed in this pull request?

Currently, filter pushdown does not work if the Parquet schema and the Hive metastore schema use different letter cases, even when spark.sql.caseSensitive is false.

For example:
```scala
spark.sparkContext.hadoopConfiguration.setInt("parquet.block.size", 8 * 1024 * 1024)
spark.range(1, 40 * 1024 * 1024, 1, 1).sortWithinPartitions("id").write.parquet("/tmp/t")
sql("CREATE TABLE t (ID LONG) USING parquet LOCATION '/tmp/t'")
sql("select * from t where id < 100L").write.csv("/tmp/id")
```

Although Spark generates the filter "ID < 100L", it is not actually pushed down into Parquet, so Spark still does a full table scan when reading.
This PR adds case-insensitive field resolution so that the pushdown works.

Before - "ID < 100L" fails to be pushed down:
<img width="273" alt="screen shot 2018-08-23 at 10 08 26 pm" src="https://user-images.githubusercontent.com/2989575/44530558-40ef8b00-a721-11e8-8abc-7f97671590d3.png">
After - "ID < 100L" is pushed down successfully:
<img width="267" alt="screen shot 2018-08-23 at 10 08 40 pm" src="https://user-images.githubusercontent.com/2989575/44530567-44831200-a721-11e8-8634-e9f664b33d39.png">
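
Conceptually, the resolution works like the standalone sketch below — a simplified, hypothetical illustration rather than the exact code in this diff (`resolveFieldName` is an invented helper): the pushed-down column name is matched against the physical Parquet field names without regard to case, and pushdown is skipped when the match is missing or ambiguous (two fields that differ only by case, see SPARK-25132).

```scala
import java.util.Locale

// Hypothetical helper, not part of this PR: resolve a pushed-down column name
// against the physical Parquet field names case-insensitively.
def resolveFieldName(parquetFieldNames: Seq[String], pushedName: String): Option[String] = {
  // Group the physical names by their lower-cased form.
  val byLowerCase = parquetFieldNames.groupBy(_.toLowerCase(Locale.ROOT))
  byLowerCase.get(pushedName.toLowerCase(Locale.ROOT)) match {
    case Some(Seq(physicalName)) => Some(physicalName) // unique match: push down using the file's name
    case _                       => None               // missing or ambiguous: skip pushdown
  }
}

// "ID" in the query resolves to the file's "id" column...
assert(resolveFieldName(Seq("id", "name"), "ID").contains("id"))
// ...but duplicate fields differing only by case are skipped (see SPARK-25132).
assert(resolveFieldName(Seq("a", "A"), "A").isEmpty)
```

In the actual change, this logic lives in `ParquetFilters.getFieldMap`, which returns a `CaseInsensitiveMap` of deduplicated primitive fields when spark.sql.caseSensitive is false (see the diff below).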

## How was this patch tested?

Added unit tests.

Closes #22197 from yucai/SPARK-25207.

Authored-by: yucai <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
yucai authored and cloud-fan committed Aug 31, 2018
1 parent 515708d commit 8d9495a
Showing 3 changed files with 179 additions and 29 deletions.
@@ -347,6 +347,7 @@ class ParquetFileFormat
val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
val isCaseSensitive = sqlConf.caseSensitiveAnalysis

(file: PartitionedFile) => {
assert(file.partitionValues.numFields == partitionSchema.size)
@@ -372,7 +373,7 @@
val pushed = if (enableParquetFilterPushDown) {
val parquetSchema = footerFileMetaData.getSchema
val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal,
pushDownStringStartWith, pushDownInFilterThreshold)
pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
filters
// Collects all converted Parquet filter predicates. Notice that not all predicates can be
// converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Timestamp}
import java.util.Locale

import scala.collection.JavaConverters.asScalaBufferConverter

@@ -31,7 +32,7 @@ import org.apache.parquet.schema.OriginalType._
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate
import org.apache.spark.sql.sources
import org.apache.spark.unsafe.types.UTF8String
@@ -44,7 +45,18 @@ private[parquet] class ParquetFilters(
pushDownTimestamp: Boolean,
pushDownDecimal: Boolean,
pushDownStartWith: Boolean,
pushDownInFilterThreshold: Int) {
pushDownInFilterThreshold: Int,
caseSensitive: Boolean) {

/**
* Holds a single field information stored in the underlying parquet file.
*
* @param fieldName field name in parquet file
* @param fieldType field type related info in parquet file
*/
private case class ParquetField(
fieldName: String,
fieldType: ParquetSchemaType)

private case class ParquetSchemaType(
originalType: OriginalType,
@@ -350,25 +362,38 @@
}

/**
* Returns a map from name of the column to the data type, if predicate push down applies.
* Returns a map, which contains parquet field name and data type, if predicate push down applies.
*/
private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match {
case m: MessageType =>
// Here we don't flatten the fields in the nested schema but just look up through
// root fields. Currently, accessing to nested fields does not push down filters
// and it does not support to create filters for them.
m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
f.getName -> ParquetSchemaType(
f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)
}.toMap
case _ => Map.empty[String, ParquetSchemaType]
private def getFieldMap(dataType: MessageType): Map[String, ParquetField] = {
// Here we don't flatten the fields in the nested schema but just look up through
// root fields. Currently, accessing to nested fields does not push down filters
// and it does not support to create filters for them.
val primitiveFields =
dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
f.getName -> ParquetField(f.getName,
ParquetSchemaType(f.getOriginalType,
f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
}
if (caseSensitive) {
primitiveFields.toMap
} else {
// Don't consider ambiguity here, i.e. more than one field is matched in case insensitive
// mode, just skip pushdown for these fields, they will trigger Exception when reading,
// See: SPARK-25132.
val dedupPrimitiveFields =
primitiveFields
.groupBy(_._1.toLowerCase(Locale.ROOT))
.filter(_._2.size == 1)
.mapValues(_.head._2)
CaseInsensitiveMap(dedupPrimitiveFields)
}
}

/**
* Converts data sources filters to Parquet filter predicates.
*/
def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = {
val nameToType = getFieldMap(schema)
val nameToParquetField = getFieldMap(schema)

// Decimal type must make sure that filter value's scale matched the file.
// If doesn't matched, which would cause data corruption.
@@ -381,7 +406,7 @@
// Parquet's type in the given file should be matched to the value's type
// in the pushed filter in order to push down the filter to Parquet.
def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
value == null || (nameToType(name) match {
value == null || (nameToParquetField(name).fieldType match {
case ParquetBooleanType => value.isInstanceOf[JBoolean]
case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
case ParquetLongType => value.isInstanceOf[JLong]
@@ -408,7 +433,7 @@
// filters for the column having dots in the names. Thus, we do not push down such filters.
// See SPARK-20364.
def canMakeFilterOn(name: String, value: Any): Boolean = {
nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
}

// NOTE:
@@ -428,29 +453,39 @@

predicate match {
case sources.IsNull(name) if canMakeFilterOn(name, null) =>
makeEq.lift(nameToType(name)).map(_(name, null))
makeEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, null))
case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
makeNotEq.lift(nameToType(name)).map(_(name, null))
makeNotEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, null))

case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
makeEq.lift(nameToType(name)).map(_(name, value))
makeEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))
case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
makeNotEq.lift(nameToType(name)).map(_(name, value))
makeNotEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))

case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
makeEq.lift(nameToType(name)).map(_(name, value))
makeEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))
case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
makeNotEq.lift(nameToType(name)).map(_(name, value))
makeNotEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))

case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
makeLt.lift(nameToType(name)).map(_(name, value))
makeLt.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))
case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
makeLtEq.lift(nameToType(name)).map(_(name, value))
makeLtEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))

case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
makeGt.lift(nameToType(name)).map(_(name, value))
makeGt.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))
case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
makeGtEq.lift(nameToType(name)).map(_(name, value))
makeGtEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, value))

case sources.And(lhs, rhs) =>
// At here, it is not safe to just convert one side if we do not understand the
Expand All @@ -477,7 +512,8 @@ private[parquet] class ParquetFilters(
case sources.In(name, values) if canMakeFilterOn(name, values.head)
&& values.distinct.length <= pushDownInFilterThreshold =>
values.distinct.flatMap { v =>
makeEq.lift(nameToType(name)).map(_(name, v))
makeEq.lift(nameToParquetField(name).fieldType)
.map(_(nameToParquetField(name).fieldName, v))
}.reduceLeftOption(FilterApi.or)

case sources.StringStartsWith(name, prefix)
@@ -25,6 +25,7 @@ import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operato
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}

import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
@@ -60,7 +61,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
private lazy val parquetFilters =
new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp,
conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith,
conf.parquetFilterPushDownInFilterThreshold)
conf.parquetFilterPushDownInFilterThreshold, conf.caseSensitiveAnalysis)

override def beforeEach(): Unit = {
super.beforeEach()
@@ -1021,6 +1022,118 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
}
}
}

test("SPARK-25207: Case-insensitive field resolution for pushdown when reading parquet") {
def createParquetFilter(caseSensitive: Boolean): ParquetFilters = {
new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp,
conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith,
conf.parquetFilterPushDownInFilterThreshold, caseSensitive)
}
val caseSensitiveParquetFilters = createParquetFilter(caseSensitive = true)
val caseInsensitiveParquetFilters = createParquetFilter(caseSensitive = false)

def testCaseInsensitiveResolution(
schema: StructType,
expected: FilterPredicate,
filter: sources.Filter): Unit = {
val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema)

assertResult(Some(expected)) {
caseInsensitiveParquetFilters.createFilter(parquetSchema, filter)
}
assertResult(None) {
caseSensitiveParquetFilters.createFilter(parquetSchema, filter)
}
}

val schema = StructType(Seq(StructField("cint", IntegerType)))

testCaseInsensitiveResolution(
schema, FilterApi.eq(intColumn("cint"), null.asInstanceOf[Integer]), sources.IsNull("CINT"))

testCaseInsensitiveResolution(
schema,
FilterApi.notEq(intColumn("cint"), null.asInstanceOf[Integer]),
sources.IsNotNull("CINT"))

testCaseInsensitiveResolution(
schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualTo("CINT", 1000))

testCaseInsensitiveResolution(
schema,
FilterApi.notEq(intColumn("cint"), 1000: Integer),
sources.Not(sources.EqualTo("CINT", 1000)))

testCaseInsensitiveResolution(
schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualNullSafe("CINT", 1000))

testCaseInsensitiveResolution(
schema,
FilterApi.notEq(intColumn("cint"), 1000: Integer),
sources.Not(sources.EqualNullSafe("CINT", 1000)))

testCaseInsensitiveResolution(
schema,
FilterApi.lt(intColumn("cint"), 1000: Integer), sources.LessThan("CINT", 1000))

testCaseInsensitiveResolution(
schema,
FilterApi.ltEq(intColumn("cint"), 1000: Integer),
sources.LessThanOrEqual("CINT", 1000))

testCaseInsensitiveResolution(
schema, FilterApi.gt(intColumn("cint"), 1000: Integer), sources.GreaterThan("CINT", 1000))

testCaseInsensitiveResolution(
schema,
FilterApi.gtEq(intColumn("cint"), 1000: Integer),
sources.GreaterThanOrEqual("CINT", 1000))

testCaseInsensitiveResolution(
schema,
FilterApi.or(
FilterApi.eq(intColumn("cint"), 10: Integer),
FilterApi.eq(intColumn("cint"), 20: Integer)),
sources.In("CINT", Array(10, 20)))

val dupFieldSchema = StructType(
Seq(StructField("cint", IntegerType), StructField("cINT", IntegerType)))
val dupParquetSchema = new SparkToParquetSchemaConverter(conf).convert(dupFieldSchema)
assertResult(None) {
caseInsensitiveParquetFilters.createFilter(
dupParquetSchema, sources.EqualTo("CINT", 1000))
}
}

test("SPARK-25207: exception when duplicate fields in case-insensitive mode") {
withTempPath { dir =>
val count = 10
val tableName = "spark_25207"
val tableDir = dir.getAbsoluteFile + "/table"
withTable(tableName) {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
spark.range(count).selectExpr("id as A", "id as B", "id as b")
.write.mode("overwrite").parquet(tableDir)
}
sql(
s"""
|CREATE TABLE $tableName (A LONG, B LONG) USING PARQUET LOCATION '$tableDir'
""".stripMargin)

withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
val e = intercept[SparkException] {
sql(s"select a from $tableName where b > 0").collect()
}
assert(e.getCause.isInstanceOf[RuntimeException] && e.getCause.getMessage.contains(
"""Found duplicate field(s) "B": [B, b] in case-insensitive mode"""))
}

withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
checkAnswer(sql(s"select A from $tableName where B > 0"), (1 until count).map(Row(_)))
}
}
}
}
}

class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {