[SPARK-47168][SQL] Disable parquet filter pushdown when working with non-default collated strings

### What changes were proposed in this pull request?

Disable Parquet filter pushdown when an expression references a non-default collated string column, or such a field inside a struct.

### Why are the changes needed?

Parquet min/max statistics are unaware of collation, so data skipping based on them can produce incorrect results for certain collation types (for example, lowercase collation).
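For illustration only (a sketch, not part of the patch; the collation name and literals are borrowed from the new test below, and the row-group layout is hypothetical), the mismatch looks like this:

```scala
// Under the case-insensitive UCS_BASIC_LCASE collation, 'AAA' and 'aaa'
// compare as equal:
spark.sql(
  "SELECT COLLATE('AAA', 'UCS_BASIC_LCASE') = COLLATE('aaa', 'UCS_BASIC_LCASE') AS eq"
).show()  // eq = true

// Parquet min/max statistics, however, are collected over raw binary values.
// A row group containing only 'AAA' has binary min = max = 'AAA', so a
// pushed-down binary filter `c = 'aaa'` could skip that row group entirely
// and silently drop rows that do match under the collation. Keeping the
// filter in Spark evaluates it with collation-aware semantics instead.
```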

### Does this PR introduce _any_ user-facing change?

Users should now always get the correct result when using collated string columns.

### How was this patch tested?

With new UTs in `FileBasedDataSourceSuite`

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#45262 from stefankandic/disableFilterPushdown.

Lead-authored-by: Stefan Kandic <[email protected]>
Co-authored-by: Stefan Kandic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
2 people authored and cloud-fan committed Mar 5, 2024
1 parent 6534a33 commit 502fa80
Showing 7 changed files with 86 additions and 9 deletions.
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType, StructField, StructType}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.SparkSchemaUtils

@@ -293,4 +293,14 @@ private[spark] object SchemaUtils {
* @return The escaped string.
*/
def escapeMetaCharacters(str: String): String = SparkSchemaUtils.escapeMetaCharacters(str)

/**
* Checks if a given data type has a non-default collation string type.
*/
def hasNonDefaultCollatedString(dt: DataType): Boolean = {
dt.existsRecursively {
case st: StringType => !st.isDefaultCollation
case _ => false
}
}
}
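For context, a small usage sketch of the new helper (not part of the diff; since `SchemaUtils` is `private[spark]`, this would run from code inside Spark, e.g. a test suite where a `spark` session is available):

```scala
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.util.SchemaUtils

// Plain StringType uses the default (binary) collation, so it is not flagged.
SchemaUtils.hasNonDefaultCollatedString(StringType)            // false

// The check recurses into nested types, so a struct (or array/map) wrapping
// a non-default collated string field is flagged as well.
val schema = spark.sql(
  "SELECT struct(COLLATE('a', 'UCS_BASIC_LCASE')) AS s").schema
SchemaUtils.hasNonDefaultCollatedString(schema("s").dataType)  // true
```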
@@ -28,15 +28,15 @@ import org.json4s.jackson.Serialization
import org.apache.spark.{SparkException, SparkUpgradeException}
import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, ExpressionSet, GetStructField, PredicateHelper}
import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.util.Utils


@@ -279,4 +279,21 @@ object DataSourceUtils extends PredicateHelper {
dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet))
(ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters)
}

/**
* Determines whether a filter should be pushed down to the data source or not.
*
* @param expression The filter expression to be evaluated.
* @return A boolean indicating whether the filter should be pushed down or not.
*/
def shouldPushFilter(expression: Expression): Boolean = {
expression.deterministic && !expression.exists {
case childExpression @ (_: Attribute | _: GetStructField) =>
// don't push down filters for types with non-default collation
// as it could lead to incorrect results
SchemaUtils.hasNonDefaultCollatedString(childExpression.dataType)

case _ => false
}
}
}
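A short sketch of how the new predicate behaves (illustrative only; the attribute name is made up, and constructing a non-default collated `StringType` instance is elided because that API is not shown in this diff):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.types.StringType

// A deterministic filter over a default-collation string column remains
// eligible for pushdown.
val c = AttributeReference("c", StringType)()
DataSourceUtils.shouldPushFilter(EqualTo(c, Literal("aaa")))   // true

// If c's type were a non-default collated StringType (or the filter read such
// a field through GetStructField), the same call would return false and the
// filter would stay in Spark to be evaluated after the scan, as would any
// non-deterministic filter.
```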
@@ -160,8 +160,10 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
// - filters that need to be evaluated again after the scan
val filterSet = ExpressionSet(filters)

+      val filtersToPush = filters.filter(f => DataSourceUtils.shouldPushFilter(f))
+
val normalizedFilters = DataSourceStrategy.normalizeExprs(
-        filters.filter(_.deterministic), l.output)
+        filtersToPush, l.output)

val partitionColumns =
l.resolve(
@@ -63,7 +63,8 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
_))
if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty =>
val normalizedFilters = DataSourceStrategy.normalizeExprs(
-        filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)),
+        filters.filter(f =>
+          !SubqueryExpression.hasSubquery(f) && DataSourceUtils.shouldPushFilter(f)),
logicalRelation.output)
val (partitionKeyFilters, _) = DataSourceUtils
.getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)
@@ -70,9 +70,9 @@ abstract class FileScanBuilder(
}

override def pushFilters(filters: Seq[Expression]): Seq[Expression] = {
-    val (deterministicFilters, nonDeterminsticFilters) = filters.partition(_.deterministic)
+    val (filtersToPush, filtersToRemain) = filters.partition(DataSourceUtils.shouldPushFilter)
val (partitionFilters, dataFilters) =
-      DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, deterministicFilters)
+      DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filtersToPush)
this.partitionFilters = partitionFilters
this.dataFilters = dataFilters
val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter]
@@ -83,7 +83,7 @@
}
}
pushedDataFilters = pushDataFilters(translatedFilters.toArray)
-    dataFilters ++ nonDeterminsticFilters
+    dataFilters ++ filtersToRemain
}

override def pushedFilters: Array[Predicate] = pushedDataFilters.map(_.toV2)
@@ -72,6 +72,7 @@ select * from t1 where ucs_basic_lcase = 'aaa' collate 'ucs_basic_lcase'
-- !query schema
struct<ucs_basic:string,ucs_basic_lcase:string COLLATE 'UCS_BASIC_LCASE'>
-- !query output
+AAA AAA
aaa aaa


@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterTha
import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
+import org.apache.spark.sql.execution.{ExplainMode, FileSourceScanLike, SimpleMode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
@@ -1246,6 +1246,52 @@ class FileBasedDataSourceSuite extends QueryTest
}
}
}

test("disable filter pushdown for collated strings") {
Seq("parquet").foreach { format =>
withTempPath { path =>
val collation = "'UCS_BASIC_LCASE'"
val df = sql(
s"""SELECT
| COLLATE(c, $collation) as c1,
| struct(COLLATE(c, $collation)) as str,
| named_struct('f1', named_struct('f2', COLLATE(c, $collation), 'f3', 1)) as namedstr,
| array(COLLATE(c, $collation)) as arr,
| map(COLLATE(c, $collation), 1) as map1,
| map(1, COLLATE(c, $collation)) as map2
|FROM VALUES ('aaa'), ('AAA'), ('bbb')
|as data(c)
|""".stripMargin)

df.write.format(format).save(path.getAbsolutePath)

// filter and expected result
val filters = Seq(
("==", Seq(Row("aaa"), Row("AAA"))),
("!=", Seq(Row("bbb"))),
("<", Seq()),
("<=", Seq(Row("aaa"), Row("AAA"))),
(">", Seq(Row("bbb"))),
(">=", Seq(Row("aaa"), Row("AAA"), Row("bbb"))))

filters.foreach { filter =>
val readback = spark.read
.parquet(path.getAbsolutePath)
.where(s"c1 ${filter._1} collate('aaa', $collation)")
.where(s"str ${filter._1} struct(collate('aaa', $collation))")
.where(s"namedstr.f1.f2 ${filter._1} collate('aaa', $collation)")
.where(s"arr ${filter._1} array(collate('aaa', $collation))")
.where(s"map_keys(map1) ${filter._1} array(collate('aaa', $collation))")
.where(s"map_values(map2) ${filter._1} array(collate('aaa', $collation))")
.select("c1")

val explain = readback.queryExecution.explainString(ExplainMode.fromString("extended"))
assert(explain.contains("PushedFilters: []"))
checkAnswer(readback, filter._2)
}
}
}
}
}

object TestingUDT {