From edab4ad607946b5773f2807664d27a9a5fc6f475 Mon Sep 17 00:00:00 2001
From: Andrew Duffy
Date: Tue, 16 Aug 2016 18:57:15 +0100
Subject: [PATCH] SPARK-17091: ParquetFilters rewrite IN to OR of Eq

---
 .../datasources/parquet/ParquetFilters.scala  |  5 ++
 .../parquet/ParquetFilterSuite.scala          | 57 +++++++++++++------
 2 files changed, 45 insertions(+), 17 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index a6e9788097728..1db46e4122525 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -234,6 +234,11 @@ private[parquet] object ParquetFilters {
       case sources.Not(pred) =>
         createFilter(schema, pred).map(FilterApi.not)
 
+      case sources.In(name, values) if dataTypeOf.contains(name) =>
+        values.flatMap { v =>
+          makeEq.lift(dataTypeOf(name)).map(_(name, v))
+        }.reduceLeftOption(FilterApi.or)
+
       case _ => None
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 4246b54c21f0c..c6f691f7f55ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -40,8 +40,8 @@ import org.apache.spark.util.{AccumulatorContext, LongAccumulator}
  * NOTE:
  *
  * 1. `!(a cmp b)` is always transformed to its negated form `a cmp' b` by the
- *    `BooleanSimplification` optimization rule whenever possible. As a result, predicate `!(a < 1)`
- *    results in a `GtEq` filter predicate rather than a `Not`.
+ *    `BooleanSimplification` optimization rule whenever possible. As a result, predicate
+ *    `!(a < 1)` results in a `GtEq` filter predicate rather than a `Not`.
  *
  * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
  *    data type is nullable.
@@ -369,7 +369,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
 
   test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") {
     import testImplicits._
-    Seq("true", "false").map { vectorized =>
+    Seq("true", "false").foreach { vectorized =>
       withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
         SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true",
         SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
@@ -535,25 +535,48 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     import testImplicits._
 
     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
-      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
       withTempPath { dir =>
         val path = s"${dir.getCanonicalPath}/table"
         (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path)
 
-        Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) =>
-          withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
-            val accu = new LongAccumulator
-            accu.register(sparkContext, Some("numRowGroups"))
-
-            val df = spark.read.parquet(path).filter("a < 100")
-            df.foreachPartition(_.foreach(v => accu.add(0)))
-            df.collect
-
-            val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
-            assert(numRowGroups.isDefined)
-            assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
-            AccumulatorContext.remove(accu.id)
+        Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0))
+          .foreach { case (push, func) =>
+            withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
+              val accu = new LongAccumulator
+              accu.register(sparkContext, Some("numRowGroups"))
+
+              val df = spark.read.parquet(path).filter("a < 100")
+              df.foreachPartition(_.foreach(v => accu.add(0)))
+              df.collect
+
+              val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
+              assert(numRowGroups.isDefined)
+              assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
+              AccumulatorContext.remove(accu.id)
+            }
          }
+      }
+    }
+  }
+
+  test("In filters are pushed down") {
+    import testImplicits._
+    withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        withTempPath { dir =>
+          val path = s"${dir.getCanonicalPath}/table1"
+          (1 to 5).map(i => (i.toFloat, i%3)).toDF("a", "b").write.parquet(path)
+          val df = spark.read.parquet(path).where("b in (0,2)")
+          assert(stripSparkFilter(df).count == 3)
+          val df1 = spark.read.parquet(path).where("not (b in (1))")
+          assert(stripSparkFilter(df1).count == 3)
+          val df2 = spark.read.parquet(path).where("not (b in (1,3) or a <= 2)")
+          assert(stripSparkFilter(df2).count == 2)
+          val df3 = spark.read.parquet(path).where("not (b in (1,3) and a <= 2)")
+          assert(stripSparkFilter(df3).count == 4)
+          val df4 = spark.read.parquet(path).where("not (a <= 2)")
+          assert(stripSparkFilter(df4).count == 3)
        }
      }
    }
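
For context on the technique the new `sources.In` case applies: an IN predicate is pushed down to Parquet by expanding it into a disjunction of equality filters, folded together with OR via `reduceLeftOption` (so an empty value list yields `None` and nothing is pushed down). The standalone Scala sketch below is not part of the patch; `Pred`, `Eq`, and `Or` are hypothetical stand-ins for Parquet's `FilterPredicate`, `FilterApi.eq`, and `FilterApi.or`, used only to illustrate the same fold-with-OR strategy.

    // Illustrative sketch only -- NOT part of the patch. Pred/Eq/Or are hypothetical
    // stand-ins for Parquet's FilterPredicate / FilterApi.eq / FilterApi.or.
    sealed trait Pred
    case class Eq(column: String, value: Any) extends Pred
    case class Or(left: Pred, right: Pred) extends Pred

    object InRewriteSketch {
      // Rewrite `column IN (v1, v2, ...)` as `column = v1 OR column = v2 OR ...`.
      // An empty value list yields None, mirroring reduceLeftOption in the patch,
      // in which case Spark evaluates the filter itself instead of pushing it down.
      def rewriteIn(column: String, values: Seq[Any]): Option[Pred] =
        values
          .map(v => Eq(column, v): Pred)        // one equality predicate per IN value
          .reduceLeftOption[Pred](Or(_, _))     // fold the equalities together with OR

      def main(args: Array[String]): Unit = {
        println(rewriteIn("b", Seq(0, 2)))      // Some(Or(Eq(b,0),Eq(b,2)))
        println(rewriteIn("b", Nil))            // None
      }
    }

Note that the real change additionally guards on `dataTypeOf.contains(name)` and builds each equality through `makeEq.lift`, so columns whose types have no equality pushdown simply produce `None` rather than failing.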