From 502fa8070c654c8c21d5491f9c1b8ad0aabbd666 Mon Sep 17 00:00:00 2001
From: Stefan Kandic
Date: Tue, 5 Mar 2024 20:58:05 +0800
Subject: [PATCH] [SPARK-47168][SQL] Disable parquet filter pushdown when working with non-default collated strings

### What changes were proposed in this pull request?

Disable parquet filter pushdown when an expression references a column, or a field of a struct, that has a non-default collation.

### Why are the changes needed?

Parquet min/max statistics are not aware of collations, so data skipping based on them can produce incorrect results for certain collation types (for example, a lowercase collation).

### Does this PR introduce _any_ user-facing change?

Users should now always get the correct result when using collated string columns.

### How was this patch tested?

With new UTs in `ParquetFilterSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #45262 from stefankandic/disableFilterPushdown.

Lead-authored-by: Stefan Kandic
Co-authored-by: Stefan Kandic <154237371+stefankandic@users.noreply.github.com>
Signed-off-by: Wenchen Fan
---
 .../apache/spark/sql/util/SchemaUtils.scala   | 12 ++++-
 .../datasources/DataSourceUtils.scala         | 21 +++++++-
 .../datasources/FileSourceStrategy.scala      |  4 +-
 .../PruneFileSourcePartitions.scala           |  3 +-
 .../datasources/v2/FileScanBuilder.scala      |  6 +--
 .../sql-tests/results/collations.sql.out      |  1 +
 .../spark/sql/FileBasedDataSourceSuite.scala  | 48 ++++++++++++++++++-
 7 files changed, 86 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
index d061fde27c86b..db547baf84d25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
 import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType, StructField, StructType}
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.SparkSchemaUtils
 
@@ -293,4 +293,14 @@ private[spark] object SchemaUtils {
    * @return The escaped string.
    */
   def escapeMetaCharacters(str: String): String = SparkSchemaUtils.escapeMetaCharacters(str)
+
+  /**
+   * Checks whether a given data type contains a string type with a non-default collation.
+   */
+  def hasNonDefaultCollatedString(dt: DataType): Boolean = {
+    dt.existsRecursively {
+      case st: StringType => !st.isDefaultCollation
+      case _ => false
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index cf02826baf3b4..5d7cda57b15b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -28,7 +28,7 @@ import org.json4s.jackson.Serialization
 import org.apache.spark.{SparkException, SparkUpgradeException}
 import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY}
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, ExpressionSet, GetStructField, PredicateHelper}
 import org.apache.spark.sql.catalyst.util.RebaseDateTime
 import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
 import org.apache.spark.util.Utils
 
 
@@ -279,4 +279,21 @@ object DataSourceUtils extends PredicateHelper {
       dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet))
     (ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters)
   }
+
+  /**
+   * Determines whether a filter should be pushed down to the data source or not.
+   *
+   * @param expression The filter expression to be evaluated.
+   * @return A boolean indicating whether the filter should be pushed down or not.
+   */
+  def shouldPushFilter(expression: Expression): Boolean = {
+    expression.deterministic && !expression.exists {
+      case childExpression @ (_: Attribute | _: GetStructField) =>
+        // don't push down filters for types with non-default collation
+        // as it could lead to incorrect results
+        SchemaUtils.hasNonDefaultCollatedString(childExpression.dataType)
+
+      case _ => false
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 6e5c463ed72b9..e4b66d72eaf85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -160,8 +160,10 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
       // - filters that need to be evaluated again after the scan
       val filterSet = ExpressionSet(filters)
 
+      val filtersToPush = filters.filter(f => DataSourceUtils.shouldPushFilter(f))
+
       val normalizedFilters = DataSourceStrategy.normalizeExprs(
-        filters.filter(_.deterministic), l.output)
+        filtersToPush, l.output)
 
       val partitionColumns =
         l.resolve(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 1dffea4e1bc87..408da5dad7684 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -63,7 +63,8 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
           _))
         if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty =>
       val normalizedFilters = DataSourceStrategy.normalizeExprs(
-        filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)),
+        filters.filter(f =>
+          !SubqueryExpression.hasSubquery(f) && DataSourceUtils.shouldPushFilter(f)),
         logicalRelation.output)
       val (partitionKeyFilters, _) = DataSourceUtils
         .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
index 447a36fe622c9..346bff980a965 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
@@ -70,9 +70,9 @@ abstract class FileScanBuilder(
   }
 
   override def pushFilters(filters: Seq[Expression]): Seq[Expression] = {
-    val (deterministicFilters, nonDeterminsticFilters) = filters.partition(_.deterministic)
+    val (filtersToPush, filtersToRemain) = filters.partition(DataSourceUtils.shouldPushFilter)
     val (partitionFilters, dataFilters) =
-      DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, deterministicFilters)
+      DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filtersToPush)
     this.partitionFilters = partitionFilters
     this.dataFilters = dataFilters
     val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter]
@@ -83,7 +83,7 @@ abstract class FileScanBuilder(
       }
     }
     pushedDataFilters = pushDataFilters(translatedFilters.toArray)
-    dataFilters ++ nonDeterminsticFilters
+    dataFilters ++ filtersToRemain
   }
 
   override def pushedFilters: Array[Predicate] = pushedDataFilters.map(_.toV2)
diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out
index 630b18674e654..dbb7eafe71083 100644
--- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out
@@ -72,6 +72,7 @@ select * from t1 where ucs_basic_lcase = 'aaa' collate 'ucs_basic_lcase'
 -- !query schema
 struct
 -- !query output
+AAA
 AAA
 aaa
 aaa
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 36af4121f30b0..47e1f37d4be52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterTha
 import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
+import org.apache.spark.sql.execution.{ExplainMode, FileSourceScanLike, SimpleMode}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
@@ -1246,6 +1246,52 @@ class FileBasedDataSourceSuite extends QueryTest
       }
     }
   }
+
+  test("disable filter pushdown for collated strings") {
+    Seq("parquet").foreach { format =>
+      withTempPath { path =>
+        val collation = "'UCS_BASIC_LCASE'"
+        val df = sql(
+          s"""SELECT
+             |  COLLATE(c, $collation) as c1,
+             |  struct(COLLATE(c, $collation)) as str,
+             |  named_struct('f1', named_struct('f2', COLLATE(c, $collation), 'f3', 1)) as namedstr,
+             |  array(COLLATE(c, $collation)) as arr,
+             |  map(COLLATE(c, $collation), 1) as map1,
+             |  map(1, COLLATE(c, $collation)) as map2
+             |FROM VALUES ('aaa'), ('AAA'), ('bbb')
+             |as data(c)
+             |""".stripMargin)
+
+        df.write.format(format).save(path.getAbsolutePath)
+
+        // filter and expected result
+        val filters = Seq(
+          ("==", Seq(Row("aaa"), Row("AAA"))),
+          ("!=", Seq(Row("bbb"))),
+          ("<", Seq()),
+          ("<=", Seq(Row("aaa"), Row("AAA"))),
+          (">", Seq(Row("bbb"))),
+          (">=", Seq(Row("aaa"), Row("AAA"), Row("bbb"))))
+
+        filters.foreach { filter =>
+          val readback = spark.read
+            .parquet(path.getAbsolutePath)
+            .where(s"c1 ${filter._1} collate('aaa', $collation)")
+            .where(s"str ${filter._1} struct(collate('aaa', $collation))")
+            .where(s"namedstr.f1.f2 ${filter._1} collate('aaa', $collation)")
+            .where(s"arr ${filter._1} array(collate('aaa', $collation))")
+            .where(s"map_keys(map1) ${filter._1} array(collate('aaa', $collation))")
+            .where(s"map_values(map2) ${filter._1} array(collate('aaa', $collation))")
+            .select("c1")
+
+          val explain = readback.queryExecution.explainString(ExplainMode.fromString("extended"))
+          assert(explain.contains("PushedFilters: []"))
+          checkAnswer(readback, filter._2)
+        }
+      }
+    }
+  }
 }
 
 object TestingUDT {
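
The scenario this patch guards against can also be reproduced end to end with a small standalone sketch, modelled on the new `FileBasedDataSourceSuite` test above. The object name, output path and session setup below are illustrative assumptions, not part of the patch:

```scala
// Minimal sketch, assuming a Spark build that includes this patch and the
// COLLATE SQL function. Object name, path and session setup are hypothetical.
import org.apache.spark.sql.SparkSession

object CollatedPushdownSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("collated-pushdown-sketch")
      .master("local[*]")
      .getOrCreate()

    val path = "/tmp/collated-parquet-sketch" // hypothetical output location

    // Write a Parquet file with a case-insensitive (UCS_BASIC_LCASE) string column.
    spark.sql(
      """SELECT COLLATE(c, 'UCS_BASIC_LCASE') AS c1
        |FROM VALUES ('aaa'), ('AAA'), ('bbb') AS data(c)
        |""".stripMargin)
      .write.mode("overwrite").parquet(path)

    // Under UCS_BASIC_LCASE, 'aaa' and 'AAA' compare equal, but Parquet min/max
    // statistics are computed on raw binary order, so pushing this equality
    // filter down could skip row groups that actually contain matching rows.
    val readback = spark.read.parquet(path)
      .where("c1 = COLLATE('aaa', 'UCS_BASIC_LCASE')")

    // With this patch the plan shows "PushedFilters: []" for the collated column,
    // and both 'aaa' and 'AAA' are returned.
    readback.explain("extended")
    readback.show()

    spark.stop()
  }
}
```

Note that the new check is type-based rather than expression-based: `DataSourceUtils.shouldPushFilter` keeps the existing determinism requirement and additionally rejects any filter whose referenced attributes or struct fields contain a string with a non-default collation anywhere in their type (structs, arrays and maps are covered through the recursion in `SchemaUtils.hasNonDefaultCollatedString`), while filters over non-collated columns keep the existing pushdown behaviour.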