[SPARK-25557][SQL] Nested column predicate pushdown for ORC
### What changes were proposed in this pull request?

We added nested column predicate pushdown for Parquet in #27728. This patch extends that support to ORC.

### Why are the changes needed?

Extending the feature to ORC brings it to parity with Parquet and improves performance for queries that filter on nested columns, since those predicates can now be pushed down to the ORC reader.
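
As an illustration (the table path, schema, and values below are hypothetical, not part of the patch), a query that filters on a nested field can now have that predicate pushed down to the ORC reader instead of being evaluated only after the scan:

```scala
// Minimal sketch, assuming a running SparkSession named `spark`.
import spark.implicits._

case class Contact(phone: String, email: String)
case class Person(name: String, contact: Contact)

Seq(Person("alice", Contact("123", "a@example.com")),
    Person("bob", Contact("456", "b@example.com")))
  .toDF()
  .write.mode("overwrite").orc("/tmp/people_orc")

// The filter on contact.phone is a nested column predicate; with this change it is
// eligible for conversion into an ORC SearchArgument rather than being applied only in Spark.
spark.read.orc("/tmp/people_orc")
  .filter($"contact.phone" === "123")
  .show()
```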

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests.

Closes #28761 from viirya/SPARK-25557.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
viirya authored and dongjoon-hyun committed Aug 7, 2020
1 parent 6c3d0a4 commit 7b6e1d5
Showing 11 changed files with 460 additions and 310 deletions.
@@ -2108,9 +2108,9 @@ object SQLConf {
.doc("A comma-separated list of data source short names or fully qualified data source " +
"implementation class names for which Spark tries to push down predicates for nested " +
"columns and/or names containing `dots` to data sources. This configuration is only " +
"effective with file-based data source in DSv1. Currently, Parquet implements " +
"both optimizations while ORC only supports predicates for names containing `dots`. The " +
"other data sources don't support this feature yet. So the default value is 'parquet,orc'.")
"effective with file-based data sources in DSv1. Currently, Parquet and ORC implement " +
"both optimizations. The other data sources don't support this feature yet. So the " +
"default value is 'parquet,orc'.")
.version("3.0.0")
.stringConf
.createWithDefault("parquet,orc")
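
The configuration key itself sits above the shown hunk; assuming it is the nested-predicate-pushdown source list (`spark.sql.optimizer.nestedPredicatePushdown.supportedFileSources` in Spark 3.x), it can be left at the default or narrowed, e.g.:

```scala
// Assumed config key (defined above this hunk); check SQLConf if in doubt.
// The default after this change pushes nested predicates for both Parquet and ORC;
// dropping "orc" from the list would disable the optimization for ORC only.
spark.conf.set(
  "spark.sql.optimizer.nestedPredicatePushdown.supportedFileSources",
  "parquet,orc")
```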
@@ -668,6 +668,8 @@ abstract class PushableColumnBase {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
def helper(e: Expression): Option[Seq[String]] = e match {
case a: Attribute =>
// Attribute that contains dot "." in name is supported only when
// nested predicate pushdown is enabled.
if (nestedPredicatePushdownEnabled || !a.name.contains(".")) {
Some(Seq(a.name))
} else {
@@ -17,8 +17,11 @@

package org.apache.spark.sql.execution.datasources.orc

import java.util.Locale

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.sources.{And, Filter}
import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType}
import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructField, StructType}

/**
* Methods that can be shared when upgrading the built-in Hive.
@@ -37,12 +40,45 @@ trait OrcFiltersBase {
}

/**
* Return true if this is a searchable type in ORC.
* Both CharType and VarcharType are cleaned at AstBuilder.
* This method returns a map which contains ORC field name and data type. Each key
* represents a column; `dots` are used as separators for nested columns. If any part
* of the names contains `dots`, it is quoted to avoid confusion. See
* `org.apache.spark.sql.connector.catalog.quoted` for implementation details.
*
* BinaryType, UserDefinedType, ArrayType and MapType are ignored.
*/
protected[sql] def isSearchableType(dataType: DataType) = dataType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
protected[sql] def getSearchableTypeMap(
schema: StructType,
caseSensitive: Boolean): Map[String, DataType] = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper

def getPrimitiveFields(
fields: Seq[StructField],
parentFieldNames: Seq[String] = Seq.empty): Seq[(String, DataType)] = {
fields.flatMap { f =>
f.dataType match {
case st: StructType =>
getPrimitiveFields(st.fields, parentFieldNames :+ f.name)
case BinaryType => None
case _: AtomicType =>
Some(((parentFieldNames :+ f.name).quoted, f.dataType))
case _ => None
}
}
}

val primitiveFields = getPrimitiveFields(schema.fields)
if (caseSensitive) {
primitiveFields.toMap
} else {
// Don't consider ambiguity here, i.e. when more than one field matches in case-insensitive
// mode; just skip pushdown for those fields, since they will trigger an exception when read.
// See: SPARK-25175.
val dedupPrimitiveFields = primitiveFields
.groupBy(_._1.toLowerCase(Locale.ROOT))
.filter(_._2.size == 1)
.mapValues(_.head._2)
CaseInsensitiveMap(dedupPrimitiveFields)
}
}
}
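
For illustration, here is roughly what the new `getSearchableTypeMap` produces for a nested schema (field names are made up; the method is `protected[sql]`, so it is only callable from Spark's own sql packages and tests):

```scala
import org.apache.spark.sql.types._

// Hypothetical input schema: a plain atomic column, a nested struct,
// a top-level name containing a dot, and a binary column.
val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("a", StructType(Seq(StructField("b", LongType)))),
  StructField("c.d", StringType),
  StructField("payload", BinaryType)))

// Expected keys in case-sensitive mode, following the quoting rules described above:
//   "id"    -> IntegerType   (top-level atomic field, plain name)
//   "a.b"   -> LongType      (nested field, dot-separated path)
//   "`c.d`" -> StringType    (name part containing a dot is back-quoted)
// "payload" is absent because BinaryType is not searchable in ORC.
```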
@@ -27,6 +27,7 @@ import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.orc.OrcFilters
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -60,10 +61,8 @@ case class OrcScanBuilder(
// changed `hadoopConf` in executors.
OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
}
val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
// TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
val newFilters = filters.filter(!_.containsNestedColumn)
_pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, newFilters).toArray
val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
_pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray
}
filters
}
@@ -22,8 +22,10 @@ import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.{DataFrame, Row, SaveMode}
import org.apache.spark.sql.functions.struct
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.StructType

/**
* A helper trait that provides convenient facilities for file-based data source testing.
@@ -103,4 +105,40 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils {
df: DataFrame, path: File): Unit = {
df.write.mode(SaveMode.Overwrite).format(dataSourceName).save(path.getCanonicalPath)
}

/**
* Takes single level `inputDF` dataframe to generate multi-level nested
* dataframes as new test data. It tests both non-nested and nested dataframes
* which are written and read back with specified datasource.
*/
protected def withNestedDataFrame(inputDF: DataFrame): Seq[(DataFrame, String, Any => Any)] = {
assert(inputDF.schema.fields.length == 1)
assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
val df = inputDF.toDF("temp")
Seq(
(
df.withColumnRenamed("temp", "a"),
"a", // zero nesting
(x: Any) => x),
(
df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
"a.b", // one level nesting
(x: Any) => Row(x)),
(
df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"),
"a.b.c", // two level nesting
(x: Any) => Row(Row(x))
),
(
df.withColumnRenamed("temp", "a.b"),
"`a.b`", // zero nesting with column name containing `dots`
(x: Any) => x
),
(
df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"),
"`a.b`.`c.d`", // one level nesting with column names containing `dots`
(x: Any) => Row(x)
)
)
}
}
@@ -143,4 +143,26 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor
FileUtils.copyURLToFile(url, file)
spark.read.orc(file.getAbsolutePath)
}

/**
* Takes a sequence of products `data` to generate multi-level nested
* dataframes as new test data. It tests both non-nested and nested dataframes
* which are written and read back with Orc datasource.
*
* This is different from [[withOrcDataFrame]] which does not
* test nested cases.
*/
protected def withNestedOrcDataFrame[T <: Product: ClassTag: TypeTag](data: Seq[T])
(runTest: (DataFrame, String, Any => Any) => Unit): Unit =
withNestedOrcDataFrame(spark.createDataFrame(data))(runTest)

protected def withNestedOrcDataFrame(inputDF: DataFrame)
(runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
withNestedDataFrame(inputDF).foreach { case (newDF, colName, resultFun) =>
withTempPath { file =>
newDF.write.format(dataSourceName).save(file.getCanonicalPath)
readFile(file.getCanonicalPath, true) { df => runTest(df, colName, resultFun) }
}
}
}
}
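
A minimal sketch of how a suite mixing in `OrcTest` might use this helper (the real suites also assert on the pushed ORC `SearchArgument`; here only the query answer is checked):

```scala
// Runs the same check once per nesting pattern produced by withNestedDataFrame:
// colName is "a", "a.b", "a.b.c", "`a.b`" or "`a.b`.`c.d`", and resultFun wraps
// the expected value in the matching level of Row nesting.
withNestedOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { case (df, colName, resultFun) =>
  checkAnswer(df.where(s"$colName = 2"), Seq(Row(resultFun(2))))
}
```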
@@ -122,34 +122,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared

private def withNestedParquetDataFrame(inputDF: DataFrame)
(runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
assert(inputDF.schema.fields.length == 1)
assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
val df = inputDF.toDF("temp")
Seq(
(
df.withColumnRenamed("temp", "a"),
"a", // zero nesting
(x: Any) => x),
(
df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
"a.b", // one level nesting
(x: Any) => Row(x)),
(
df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"),
"a.b.c", // two level nesting
(x: Any) => Row(Row(x))
),
(
df.withColumnRenamed("temp", "a.b"),
"`a.b`", // zero nesting with column name containing `dots`
(x: Any) => x
),
(
df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"),
"`a.b`.`c.d`", // one level nesting with column names containing `dots`
(x: Any) => Row(x)
)
).foreach { case (newDF, colName, resultFun) =>
withNestedDataFrame(inputDF).foreach { case (newDF, colName, resultFun) =>
withTempPath { file =>
newDF.write.format(dataSourceName).save(file.getCanonicalPath)
readParquetFile(file.getCanonicalPath) { df => runTest(df, colName, resultFun) }
@@ -27,7 +27,7 @@ import org.apache.orc.storage.serde2.io.HiveDecimalWritable

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._

@@ -68,11 +68,9 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* Create ORC filter as a SearchArgument instance.
*/
def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
// Combines all convertible filters using `And` to produce a single conjunction
// TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
val newFilters = filters.filter(!_.containsNestedColumn)
val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters))
val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters))
conjunctionOptional.map { conjunction =>
// Then tries to build a single ORC `SearchArgument` for the conjunction predicate.
// The input predicate is fully convertible. There should not be any empty result in the
@@ -228,40 +226,38 @@ private[sql] object OrcFilters extends OrcFiltersBase {
// NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()`
// call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be
// wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
// Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters
// in order to distinguish predicate pushdown for nested columns.
expression match {
case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) =>
case EqualTo(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().equals(name, getType(name), castedValue).end())

case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) =>
case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())

case LessThan(name, value) if isSearchableType(dataTypeMap(name)) =>
case LessThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())

case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())

case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) =>
case GreaterThan(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())

case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
val castedValue = castLiteralValue(value, dataTypeMap(name))
Some(builder.startNot().lessThan(name, getType(name), castedValue).end())

case IsNull(name) if isSearchableType(dataTypeMap(name)) =>
case IsNull(name) if dataTypeMap.contains(name) =>
Some(builder.startAnd().isNull(name, getType(name)).end())

case IsNotNull(name) if isSearchableType(dataTypeMap(name)) =>
case IsNotNull(name) if dataTypeMap.contains(name) =>
Some(builder.startNot().isNull(name, getType(name)).end())

case In(name, values) if isSearchableType(dataTypeMap(name)) =>
case In(name, values) if dataTypeMap.contains(name) =>
val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
Some(builder.startAnd().in(name, getType(name),
castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
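
As a rough end-to-end sketch of the new behavior (written as test-style code inside an `org.apache.spark.sql` package, since `OrcFilters` is `private[sql]`; the schema and filter values are hypothetical):

```scala
import org.apache.spark.sql.execution.datasources.orc.OrcFilters
import org.apache.spark.sql.sources.{EqualTo, IsNotNull}
import org.apache.spark.sql.types._

// Nested schema: struct column `a` with an int field `b`.
val schema = StructType(Seq(
  StructField("a", StructType(Seq(StructField("b", IntegerType))))))

// Before this patch, filters on "a.b" were dropped as nested predicates; now "a.b"
// is a key of the searchable-type map, so a SearchArgument can be produced.
val sarg = OrcFilters.createFilter(schema, Seq(IsNotNull("a.b"), EqualTo("a.b", 1)))
assert(sarg.isDefined)
```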