From 8d9495a8f1e64dbc42c3741f9bcbd4893ce3f0e9 Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 31 Aug 2018 19:24:09 +0800 Subject: [PATCH] [SPARK-25207][SQL] Case-insensitve field resolution for filter pushdown when reading Parquet ## What changes were proposed in this pull request? Currently, filter pushdown will not work if Parquet schema and Hive metastore schema are in different letter cases even spark.sql.caseSensitive is false. Like the below case: ```scala spark.sparkContext.hadoopConfiguration.setInt("parquet.block.size", 8 * 1024 * 1024) spark.range(1, 40 * 1024 * 1024, 1, 1).sortWithinPartitions("id").write.parquet("/tmp/t") sql("CREATE TABLE t (ID LONG) USING parquet LOCATION '/tmp/t'") sql("select * from t where id < 100L").write.csv("/tmp/id") ``` Although filter "ID < 100L" is generated by Spark, it fails to pushdown into parquet actually, Spark still does the full table scan when reading. This PR provides a case-insensitive field resolution to make it work. Before - "ID < 100L" fail to pushedown: screen shot 2018-08-23 at 10 08 26 pm After - "ID < 100L" pushedown sucessfully: screen shot 2018-08-23 at 10 08 40 pm ## How was this patch tested? Added UTs. Closes #22197 from yucai/SPARK-25207. Authored-by: yucai Signed-off-by: Wenchen Fan --- .../parquet/ParquetFileFormat.scala | 3 +- .../datasources/parquet/ParquetFilters.scala | 90 ++++++++++---- .../parquet/ParquetFilterSuite.scala | 115 +++++++++++++++++- 3 files changed, 179 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d7eb14356b8b1..ea4f1592a7c2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -347,6 +347,7 @@ class ParquetFileFormat val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold + val isCaseSensitive = sqlConf.caseSensitiveAnalysis (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -372,7 +373,7 @@ class ParquetFileFormat val pushed = if (enableParquetFilterPushDown) { val parquetSchema = footerFileMetaData.getSchema val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal, - pushDownStringStartWith, pushDownInFilterThreshold) + pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive) filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 58b4a769fcb62..0c286defb9406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Timestamp} +import java.util.Locale import scala.collection.JavaConverters.asScalaBufferConverter @@ -31,7 +32,7 @@ import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate import org.apache.spark.sql.sources import org.apache.spark.unsafe.types.UTF8String @@ -44,7 +45,18 @@ private[parquet] class ParquetFilters( pushDownTimestamp: Boolean, pushDownDecimal: Boolean, pushDownStartWith: Boolean, - pushDownInFilterThreshold: Int) { + pushDownInFilterThreshold: Int, + caseSensitive: Boolean) { + + /** + * Holds a single field information stored in the underlying parquet file. + * + * @param fieldName field name in parquet file + * @param fieldType field type related info in parquet file + */ + private case class ParquetField( + fieldName: String, + fieldType: ParquetSchemaType) private case class ParquetSchemaType( originalType: OriginalType, @@ -350,25 +362,38 @@ private[parquet] class ParquetFilters( } /** - * Returns a map from name of the column to the data type, if predicate push down applies. + * Returns a map, which contains parquet field name and data type, if predicate push down applies. */ - private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match { - case m: MessageType => - // Here we don't flatten the fields in the nested schema but just look up through - // root fields. Currently, accessing to nested fields does not push down filters - // and it does not support to create filters for them. - m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => - f.getName -> ParquetSchemaType( - f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata) - }.toMap - case _ => Map.empty[String, ParquetSchemaType] + private def getFieldMap(dataType: MessageType): Map[String, ParquetField] = { + // Here we don't flatten the fields in the nested schema but just look up through + // root fields. Currently, accessing to nested fields does not push down filters + // and it does not support to create filters for them. + val primitiveFields = + dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => + f.getName -> ParquetField(f.getName, + ParquetSchemaType(f.getOriginalType, + f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)) + } + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive + // mode, just skip pushdown for these fields, they will trigger Exception when reading, + // See: SPARK-25132. + val dedupPrimitiveFields = + primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields) + } } /** * Converts data sources filters to Parquet filter predicates. */ def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { - val nameToType = getFieldMap(schema) + val nameToParquetField = getFieldMap(schema) // Decimal type must make sure that filter value's scale matched the file. // If doesn't matched, which would cause data corruption. @@ -381,7 +406,7 @@ private[parquet] class ParquetFilters( // Parquet's type in the given file should be matched to the value's type // in the pushed filter in order to push down the filter to Parquet. def valueCanMakeFilterOn(name: String, value: Any): Boolean = { - value == null || (nameToType(name) match { + value == null || (nameToParquetField(name).fieldType match { case ParquetBooleanType => value.isInstanceOf[JBoolean] case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number] case ParquetLongType => value.isInstanceOf[JLong] @@ -408,7 +433,7 @@ private[parquet] class ParquetFilters( // filters for the column having dots in the names. Thus, we do not push down such filters. // See SPARK-20364. def canMakeFilterOn(name: String, value: Any): Boolean = { - nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) + nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) } // NOTE: @@ -428,29 +453,39 @@ private[parquet] class ParquetFilters( predicate match { case sources.IsNull(name) if canMakeFilterOn(name, null) => - makeEq.lift(nameToType(name)).map(_(name, null)) + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, null)) case sources.IsNotNull(name) if canMakeFilterOn(name, null) => - makeNotEq.lift(nameToType(name)).map(_(name, null)) + makeNotEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, null)) case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => - makeEq.lift(nameToType(name)).map(_(name, value)) + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => - makeNotEq.lift(nameToType(name)).map(_(name, value)) + makeNotEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => - makeEq.lift(nameToType(name)).map(_(name, value)) + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => - makeNotEq.lift(nameToType(name)).map(_(name, value)) + makeNotEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.LessThan(name, value) if canMakeFilterOn(name, value) => - makeLt.lift(nameToType(name)).map(_(name, value)) + makeLt.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeLtEq.lift(nameToType(name)).map(_(name, value)) + makeLtEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => - makeGt.lift(nameToType(name)).map(_(name, value)) + makeGt.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => - makeGtEq.lift(nameToType(name)).map(_(name, value)) + makeGtEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, value)) case sources.And(lhs, rhs) => // At here, it is not safe to just convert one side if we do not understand the @@ -477,7 +512,8 @@ private[parquet] class ParquetFilters( case sources.In(name, values) if canMakeFilterOn(name, values.head) && values.distinct.length <= pushDownInFilterThreshold => values.distinct.flatMap { v => - makeEq.lift(nameToType(name)).map(_(name, v)) + makeEq.lift(nameToParquetField(name).fieldType) + .map(_(nameToParquetField(name).fieldName, v)) }.reduceLeftOption(FilterApi.or) case sources.StringStartsWith(name, prefix) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index be4f498c921ab..7ebb75009555a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -25,6 +25,7 @@ import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operato import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -60,7 +61,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, - conf.parquetFilterPushDownInFilterThreshold) + conf.parquetFilterPushDownInFilterThreshold, conf.caseSensitiveAnalysis) override def beforeEach(): Unit = { super.beforeEach() @@ -1021,6 +1022,118 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("SPARK-25207: Case-insensitive field resolution for pushdown when reading parquet") { + def createParquetFilter(caseSensitive: Boolean): ParquetFilters = { + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, + conf.parquetFilterPushDownInFilterThreshold, caseSensitive) + } + val caseSensitiveParquetFilters = createParquetFilter(caseSensitive = true) + val caseInsensitiveParquetFilters = createParquetFilter(caseSensitive = false) + + def testCaseInsensitiveResolution( + schema: StructType, + expected: FilterPredicate, + filter: sources.Filter): Unit = { + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + assertResult(Some(expected)) { + caseInsensitiveParquetFilters.createFilter(parquetSchema, filter) + } + assertResult(None) { + caseSensitiveParquetFilters.createFilter(parquetSchema, filter) + } + } + + val schema = StructType(Seq(StructField("cint", IntegerType))) + + testCaseInsensitiveResolution( + schema, FilterApi.eq(intColumn("cint"), null.asInstanceOf[Integer]), sources.IsNull("CINT")) + + testCaseInsensitiveResolution( + schema, + FilterApi.notEq(intColumn("cint"), null.asInstanceOf[Integer]), + sources.IsNotNull("CINT")) + + testCaseInsensitiveResolution( + schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualTo("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.notEq(intColumn("cint"), 1000: Integer), + sources.Not(sources.EqualTo("CINT", 1000))) + + testCaseInsensitiveResolution( + schema, FilterApi.eq(intColumn("cint"), 1000: Integer), sources.EqualNullSafe("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.notEq(intColumn("cint"), 1000: Integer), + sources.Not(sources.EqualNullSafe("CINT", 1000))) + + testCaseInsensitiveResolution( + schema, + FilterApi.lt(intColumn("cint"), 1000: Integer), sources.LessThan("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.ltEq(intColumn("cint"), 1000: Integer), + sources.LessThanOrEqual("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, FilterApi.gt(intColumn("cint"), 1000: Integer), sources.GreaterThan("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.gtEq(intColumn("cint"), 1000: Integer), + sources.GreaterThanOrEqual("CINT", 1000)) + + testCaseInsensitiveResolution( + schema, + FilterApi.or( + FilterApi.eq(intColumn("cint"), 10: Integer), + FilterApi.eq(intColumn("cint"), 20: Integer)), + sources.In("CINT", Array(10, 20))) + + val dupFieldSchema = StructType( + Seq(StructField("cint", IntegerType), StructField("cINT", IntegerType))) + val dupParquetSchema = new SparkToParquetSchemaConverter(conf).convert(dupFieldSchema) + assertResult(None) { + caseInsensitiveParquetFilters.createFilter( + dupParquetSchema, sources.EqualTo("CINT", 1000)) + } + } + + test("SPARK-25207: exception when duplicate fields in case-insensitive mode") { + withTempPath { dir => + val count = 10 + val tableName = "spark_25207" + val tableDir = dir.getAbsoluteFile + "/table" + withTable(tableName) { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + spark.range(count).selectExpr("id as A", "id as B", "id as b") + .write.mode("overwrite").parquet(tableDir) + } + sql( + s""" + |CREATE TABLE $tableName (A LONG, B LONG) USING PARQUET LOCATION '$tableDir' + """.stripMargin) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val e = intercept[SparkException] { + sql(s"select a from $tableName where b > 0").collect() + } + assert(e.getCause.isInstanceOf[RuntimeException] && e.getCause.getMessage.contains( + """Found duplicate field(s) "B": [B, b] in case-insensitive mode""")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer(sql(s"select A from $tableName where B > 0"), (1 until count).map(Row(_))) + } + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {