Commit

Merge pull request apache#41 from palantir/pw/parquetInRewritw
SPARK-17091: ParquetFilters rewrite IN to OR of Eq
pwoody authored Sep 27, 2016
2 parents 4e358e9 + edab4ad commit 271c85e
Showing 2 changed files with 45 additions and 17 deletions.
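
In short: ParquetFilters learns to push a Spark `sources.In` filter down to Parquet by expanding it into one `Eq` predicate per value, joined with `or`, so that row groups whose statistics rule out every listed value can be skipped at read time. A minimal sketch of the predicate shape the rewrite produces, built directly against parquet-mr's FilterApi (the column name and values here are illustrative, not taken from the diff):

  import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}

  // For a predicate like `b IN (0, 2)` on an INT32 column, the pushed-down
  // Parquet filter is the OR of one Eq per value.
  val b = FilterApi.intColumn("b")
  val inAsOrOfEq: FilterPredicate =
    FilterApi.or(
      FilterApi.eq(b, Int.box(0)),
      FilterApi.eq(b, Int.box(2)))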
ParquetFilters.scala

@@ -234,6 +234,11 @@ private[parquet] object ParquetFilters {
       case sources.Not(pred) =>
         createFilter(schema, pred).map(FilterApi.not)
 
+      case sources.In(name, values) if dataTypeOf.contains(name) =>
+        values.flatMap { v =>
+          makeEq.lift(dataTypeOf(name)).map(_(name, v))
+        }.reduceLeftOption(FilterApi.or)
+
       case _ => None
     }
   }
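
Note how the new case degrades when a column's type has no Parquet equality constructor: `makeEq.lift` turns the unsupported type into None, `flatMap` then yields no predicates, and `reduceLeftOption` returns None, so the IN filter simply stays on the Spark side instead of being pushed down. A plain-Scala sketch of that lift/flatMap/reduceLeftOption pattern (the stringly-typed `makeEq` below is a stand-in for Spark's real constructor table, not its actual code):

  // Lifting a PartialFunction yields None for unsupported inputs
  // instead of throwing a MatchError.
  val makeEq: PartialFunction[String, Int => String] = {
    case "int" => v => s"eq(col, $v)"
  }

  val pushed = Seq(0, 2)
    .flatMap(v => makeEq.lift("int").map(_(v)))
    .reduceLeftOption((l, r) => s"or($l, $r)")  // Some(or(eq(col, 0), eq(col, 2)))

  val notPushed = Seq(0, 2)
    .flatMap(v => makeEq.lift("decimal").map(_(v)))
    .reduceLeftOption((l, r) => s"or($l, $r)")  // None: the filter stays in Spark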
Expand Down
ParquetFilterSuite.scala

@@ -40,8 +40,8 @@ import org.apache.spark.util.{AccumulatorContext, LongAccumulator}
  * NOTE:
  *
  * 1. `!(a cmp b)` is always transformed to its negated form `a cmp' b` by the
- *    `BooleanSimplification` optimization rule whenever possible. As a result, predicate `!(a < 1)`
- *    results in a `GtEq` filter predicate rather than a `Not`.
+ *    `BooleanSimplification` optimization rule whenever possible. As a result, predicate
+ *    `!(a < 1)` results in a `GtEq` filter predicate rather than a `Not`.
  *
  * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
  *    data type is nullable.
@@ -369,7 +369,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
 
   test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") {
     import testImplicits._
-    Seq("true", "false").map { vectorized =>
+    Seq("true", "false").foreach { vectorized =>
       withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
         SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true",
         SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) {
@@ -535,25 +535,48 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     import testImplicits._
 
     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
-      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
       withTempPath { dir =>
         val path = s"${dir.getCanonicalPath}/table"
         (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path)
 
-        Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) =>
-          withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
-            val accu = new LongAccumulator
-            accu.register(sparkContext, Some("numRowGroups"))
-
-            val df = spark.read.parquet(path).filter("a < 100")
-            df.foreachPartition(_.foreach(v => accu.add(0)))
-            df.collect
-
-            val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
-            assert(numRowGroups.isDefined)
-            assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
-            AccumulatorContext.remove(accu.id)
+        Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0))
+          .foreach { case (push, func) =>
+            withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
+              val accu = new LongAccumulator
+              accu.register(sparkContext, Some("numRowGroups"))
+
+              val df = spark.read.parquet(path).filter("a < 100")
+              df.foreachPartition(_.foreach(v => accu.add(0)))
+              df.collect
+
+              val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
+              assert(numRowGroups.isDefined)
+              assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
+              AccumulatorContext.remove(accu.id)
+            }
+          }
         }
       }
     }
 
+  test("In filters are pushed down") {
+    import testImplicits._
+    withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        withTempPath { dir =>
+          val path = s"${dir.getCanonicalPath}/table1"
+          (1 to 5).map(i => (i.toFloat, i%3)).toDF("a", "b").write.parquet(path)
+          val df = spark.read.parquet(path).where("b in (0,2)")
+          assert(stripSparkFilter(df).count == 3)
+          val df1 = spark.read.parquet(path).where("not (b in (1))")
+          assert(stripSparkFilter(df1).count == 3)
+          val df2 = spark.read.parquet(path).where("not (b in (1,3) or a <= 2)")
+          assert(stripSparkFilter(df2).count == 2)
+          val df3 = spark.read.parquet(path).where("not (b in (1,3) and a <= 2)")
+          assert(stripSparkFilter(df3).count == 4)
+          val df4 = spark.read.parquet(path).where("not (a <= 2)")
+          assert(stripSparkFilter(df4).count == 3)
+        }
+      }
+    }
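A quick sanity check of the new test's first assertion: the rows written are (i.toFloat, i % 3) for i = 1 to 5, so column b holds 1, 2, 0, 1, 2, and `b in (0,2)` keeps the rows for i = 2, 3, 5. Because `stripSparkFilter` removes the Spark-side Filter from the plan, the expected count of 3 can only be produced by the pushed-down Parquet predicate; the vectorized reader is disabled presumably because, at this point, row-level filtering is applied only on the parquet-mr record-reader path.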
