Implements partitioning support for data sources API
liancheng committed May 12, 2015
1 parent 3c5073a commit 327bb1d
Showing 20 changed files with 936 additions and 134 deletions.
47 changes: 35 additions & 12 deletions core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -22,22 +22,22 @@ import java.lang.reflect.Method
import java.security.PrivilegedExceptionAction
import java.util.{Arrays, Comparator}

import scala.collection.JavaConversions._
import scala.concurrent.duration._
import scala.language.postfixOps

import com.google.common.primitives.Longs
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
import org.apache.hadoop.fs.FileSystem.Statistics
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.security.{Credentials, UserGroupInformation}

import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

import scala.collection.JavaConversions._
import scala.concurrent.duration._
import scala.language.postfixOps
import org.apache.spark.{Logging, SparkConf, SparkException}

/**
* :: DeveloperApi ::
@@ -199,13 +199,36 @@ class SparkHadoopUtil extends Logging {
* that file.
*/
def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = {
def recurse(path: Path): Array[FileStatus] = {
val (directories, leaves) = fs.listStatus(path).partition(_.isDir)
leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath))
listLeafStatuses(fs, fs.getFileStatus(basePath))
}

/**
* Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the
* given path points to a file, return a single-element collection containing [[FileStatus]] of
* that file.
*/
def listLeafStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = {
def recurse(status: FileStatus): Seq[FileStatus] = {
val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDir)
leaves ++ directories.flatMap(f => listLeafStatuses(fs, f))
}

if (baseStatus.isDir) recurse(baseStatus) else Seq(baseStatus)
}

def listLeafDirStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = {
listLeafDirStatuses(fs, fs.getFileStatus(basePath))
}

def listLeafDirStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = {
def recurse(status: FileStatus): Seq[FileStatus] = {
val (directories, files) = fs.listStatus(status.getPath).partition(_.isDir)
val leaves = if (directories.isEmpty) Seq(status) else Seq.empty[FileStatus]
leaves ++ directories.flatMap(dir => listLeafDirStatuses(fs, dir))
}

val baseStatus = fs.getFileStatus(basePath)
if (baseStatus.isDir) recurse(basePath) else Array(baseStatus)
assert(baseStatus.isDir)
recurse(baseStatus)
}
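The refactored listLeafStatuses now delegates to a FileStatus-based overload, and the new listLeafDirStatuses variants return only the deepest directories, which is exactly what partition discovery needs to enumerate layouts such as year=2015/month=05. A minimal usage sketch; the /data/table path and its layout are made up for illustration:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.deploy.SparkHadoopUtil

object LeafListingSketch {
  def main(args: Array[String]): Unit = {
    val fs = FileSystem.get(new Configuration())
    val util = SparkHadoopUtil.get

    // Every leaf file, e.g. /data/table/year=2015/month=05/part-00000.parquet
    util.listLeafStatuses(fs, new Path("/data/table")).foreach(s => println(s.getPath))

    // Only the deepest directories, e.g. /data/table/year=2015/month=05
    util.listLeafDirStatuses(fs, new Path("/data/table")).foreach(s => println(s.getPath))
  }
}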

/**
@@ -275,7 +298,7 @@ class SparkHadoopUtil extends Logging {
logDebug(text + " matched " + HADOOP_CONF_PATTERN)
val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. }
val eval = Option[String](hadoopConf.get(key))
.map { value =>
.map { value =>
logDebug("Substituted " + matched + " with " + value)
text.replace(matched, value)
}
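For context, the method shown only partially above expands ${hadoopconf-<key>} placeholders with values from a Hadoop Configuration; the substring(13, length - 1) call strips the "${hadoopconf-" prefix and the closing brace. A REPL-style sketch of the same substitution idea, with a hypothetical helper name and regex rather than Spark's actual HADOOP_CONF_PATTERN:

import org.apache.hadoop.conf.Configuration
import scala.util.matching.Regex

// Hypothetical stand-in for the ${hadoopconf-...} expansion shown above.
def substituteHadoopConf(text: String, hadoopConf: Configuration): String = {
  val pattern = """\$\{hadoopconf-([^}]+)\}""".r
  pattern.replaceAllIn(text, m =>
    Regex.quoteReplacement(Option(hadoopConf.get(m.group(1))).getOrElse(m.matched)))
}

val conf = new Configuration()
conf.set("fs.defaultFS", "hdfs://namenode:8020")
substituteHadoopConf("warehouse is ${hadoopconf-fs.defaultFS}/warehouse", conf)
// -> "warehouse is hdfs://namenode:8020/warehouse"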
1 change: 1 addition & 0 deletions project/SparkBuild.scala
@@ -548,6 +548,7 @@ object TestSettings {
javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
javaOptions in Test += "-agentlib:jdwp=transport=dt_socket,server=n,address=127.0.0.1:5005",
javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
.map { case (k,v) => s"-D$k=$v" }.toSeq,
javaOptions in Test += "-ea",
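Note on the added -agentlib:jdwp line: with server=n and the default suspend=y, every forked test JVM tries to connect at startup to a debugger already listening on 127.0.0.1:5005, which is convenient while developing this patch but means the JVM will not start if nothing is listening. A stripped-down sbt sketch of the same option outside Spark's build, assuming a plain project with forked tests:

// build.sbt sketch (not Spark's SparkBuild.scala)
fork in Test := true  // javaOptions only affect forked test JVMs
javaOptions in Test += "-agentlib:jdwp=transport=dt_socket,server=n,address=127.0.0.1:5005"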
21 changes: 11 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -27,23 +27,23 @@ import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import com.fasterxml.jackson.core.JsonFactory

import org.apache.commons.lang3.StringUtils

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JacksonGenerator
import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource}
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


@@ -1372,6 +1372,7 @@ class DataFrame private[sql](
tableName,
source,
temporary = false,
Array.empty[String],
mode,
options,
logicalPlan)
@@ -1473,7 +1474,7 @@ class DataFrame private[sql](
mode: SaveMode,
options: java.util.Map[String, String],
partitionColumns: java.util.List[String]): Unit = {
???
save(source, mode, options.toMap, partitionColumns)
}

/**
@@ -1488,7 +1489,7 @@ class DataFrame private[sql](
source: String,
mode: SaveMode,
options: Map[String, String]): Unit = {
ResolvedDataSource(sqlContext, source, mode, options, this)
ResolvedDataSource(sqlContext, source, Array.empty[String], mode, options, this)
}

/**
@@ -1503,7 +1504,7 @@ class DataFrame private[sql](
mode: SaveMode,
options: Map[String, String],
partitionColumns: Seq[String]): Unit = {
???
ResolvedDataSource(sqlContext, source, partitionColumns.toArray, mode, options, this)
}

/**
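The ??? placeholders above are filled in so that partition column names flow through to ResolvedDataSource. A hypothetical call site for the new overload, assuming a DataFrame df with year and month columns and the parquet source; the path and column names are illustrative:

import org.apache.spark.sql.SaveMode

df.save(
  "parquet",
  SaveMode.Overwrite,
  Map("path" -> "/tmp/partitioned_table"),
  Seq("year", "month"))
// expected layout: /tmp/partitioned_table/year=.../month=.../part-*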
3 changes: 3 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -66,6 +66,9 @@ private[spark] object SQLConf {
// to its length exceeds the threshold.
val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold"

// Whether to perform partition discovery when loading external data sources.
val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled"

// Whether to perform eager analysis when constructing a dataframe.
// Set to false when debugging requires the ability to look at invalid query plans.
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis"
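PARTITION_DISCOVERY_ENABLED is just the key string; a short sketch of toggling it, assuming it behaves like any other spark.sql.* option under SQLContext.setConf and the SQL SET command:

// Programmatic
sqlContext.setConf("spark.sql.sources.partitionDiscovery.enabled", "true")

// Or via SQL
sqlContext.sql("SET spark.sql.sources.partitionDiscovery.enabled=true")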
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -762,7 +762,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def load(source: String, options: Map[String, String]): DataFrame = {
val resolved = ResolvedDataSource(this, None, source, options)
val resolved = ResolvedDataSource(this, None, Array.empty[String], source, options)
DataFrame(this, LogicalRelation(resolved.relation))
}

@@ -792,7 +792,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
source: String,
schema: StructType,
options: Map[String, String]): DataFrame = {
val resolved = ResolvedDataSource(this, Some(schema), source, options)
val resolved = ResolvedDataSource(this, Some(schema), Array.empty[String], source, options)
DataFrame(this, LogicalRelation(resolved.relation))
}

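Both load overloads now pass an empty partition-column array, so existing callers are unaffected. For reference, a small usage sketch of the schema-carrying overload; the source name, schema, and path are illustrative:

import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType)))

val people = sqlContext.load("json", schema, Map("path" -> "/data/people.json"))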
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -343,9 +343,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case c: CreateTableUsing if c.temporary && c.allowExisting =>
sys.error("allowExisting should be set to false when creating a temporary table.")

case CreateTableUsingAsSelect(tableName, provider, true, mode, opts, query) =>
val cmd =
CreateTempTableUsingAsSelect(tableName, provider, mode, opts, query)
case CreateTableUsingAsSelect(tableName, provider, true, partitionsCols, mode, opts, query)
if partitionsCols.nonEmpty =>
sys.error("Cannot create temporary partitioned table.")

case CreateTableUsingAsSelect(tableName, provider, true, _, mode, opts, query) =>
val cmd = CreateTempTableUsingAsSelect(
tableName, provider, Array.empty[String], mode, opts, query)
ExecutedCommand(cmd) :: Nil
case c: CreateTableUsingAsSelect if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
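The added guard rejects plans that are both temporary and partitioned, while non-partitioned temporary tables keep working. A hypothetical sketch of the two plan shapes the strategy now distinguishes; the constructor order follows the pattern above, and query stands for some analyzed logical plan:

// Rejected: temporary table with partition columns
CreateTableUsingAsSelect(
  "t", "parquet", true, Array("year"), SaveMode.ErrorIfExists, Map.empty[String, String], query)

// Still planned as CreateTempTableUsingAsSelect: no partition columns
CreateTableUsingAsSelect(
  "t", "parquet", true, Array.empty[String], SaveMode.ErrorIfExists, Map.empty[String, String], query)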
sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -136,10 +136,6 @@ private[sql] class DefaultSource
}
}

private[sql] case class Partition(values: Row, path: String)

private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition])

/**
* An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is
* intended as a full replacement of the Parquet support in Spark SQL. The old implementation will
@@ -805,7 +801,7 @@ private[sql] object ParquetRelation2 extends Logging {
val ordinalMap = metastoreSchema.zipWithIndex.map {
case (field, index) => field.name.toLowerCase -> index
}.toMap
val reorderedParquetSchema = mergedParquetSchema.sortBy(f =>
val reorderedParquetSchema = mergedParquetSchema.sortBy(f =>
ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1))

StructType(metastoreSchema.zip(reorderedParquetSchema).map {
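Partition and PartitionSpec leave the Parquet-specific file because the commit moves them into the shared data sources layer (DataSourceStrategy below pattern-matches on them). A small sketch of what the two case classes model, reusing the shapes they had here; the paths and values are made up:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// One Partition per discovered directory, e.g. /table/year=2015/month=5
val spec = PartitionSpec(
  partitionColumns = StructType(Seq(
    StructField("year", IntegerType),
    StructField("month", IntegerType))),
  partitions = Seq(
    Partition(Row(2015, 5), "hdfs://host/table/year=2015/month=5"),
    Partition(Row(2015, 6), "hdfs://host/table/year=2015/month=6")))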
sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -53,6 +53,25 @@ private[sql] object DataSourceStrategy extends Strategy {
filters,
(a, _) => t.buildScan(a)) :: Nil

case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FSBasedRelation)) =>
val selectedPartitions = prunePartitions(filters, t.partitionSpec)
val inputPaths = selectedPartitions.map(_.path).toArray

// Don't push down predicates that reference partition columns
val pushedFilters = {
val partitionColumnNames = t.partitionSpec.partitionColumns.map(_.name).toSet
filters.filter { f =>
val referencedColumnNames = f.references.map(_.name).toSet
referencedColumnNames.intersect(partitionColumnNames).isEmpty
}
}

pruneFilterProject(
l,
projectList,
pushedFilters,
(a, f) => t.buildScan(a, f, inputPaths)) :: Nil

case l @ LogicalRelation(t: TableScan) =>
createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil

@@ -63,6 +82,33 @@
case _ => Nil
}

protected def prunePartitions(
predicates: Seq[Expression],
partitionSpec: PartitionSpec): Seq[Partition] = {
val PartitionSpec(partitionColumns, partitions) = partitionSpec
val partitionColumnNames = partitionColumns.map(_.name).toSet
val partitionPruningPredicates = predicates.filter {
_.references.map(_.name).toSet.subsetOf(partitionColumnNames)
}

if (partitionPruningPredicates.nonEmpty) {
val predicate =
partitionPruningPredicates
.reduceOption(expressions.And)
.getOrElse(Literal(true))

val boundPredicate = InterpretedPredicate(predicate.transform {
case a: AttributeReference =>
val index = partitionColumns.indexWhere(a.name == _.name)
BoundReference(index, partitionColumns(index).dataType, nullable = true)
})

partitions.filter { case Partition(values, _) => boundPredicate(values) }
} else {
partitions
}
}

// Based on Public API.
protected def pruneFilterProject(
relation: LogicalRelation,
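prunePartitions keeps only the predicates that reference nothing but partition columns, folds them into one conjunction, binds each partition-column attribute to its position in the partition's values Row, and filters the partition list before any file is read; predicates on data columns are instead pushed into buildScan by the strategy above. A self-contained sketch of the same idea using plain Scala collections rather than Catalyst expressions; all names are illustrative:

// Not the Catalyst implementation, just the pruning idea.
case class PartitionSketch(values: Map[String, Any], path: String)

def prunePartitionsSketch(
    partitions: Seq[PartitionSketch],
    partitionPredicate: Map[String, Any] => Boolean): Seq[PartitionSketch] =
  partitions.filter(p => partitionPredicate(p.values))

val parts = Seq(
  PartitionSketch(Map("year" -> 2014, "month" -> 12), "/t/year=2014/month=12"),
  PartitionSketch(Map("year" -> 2015, "month" -> 1), "/t/year=2015/month=1"))

// For WHERE year = 2015 AND amount > 100:
//   year = 2015   references only partition columns -> used here to drop directories
//   amount > 100  references a data column          -> pushed down to buildScan instead
val selected = prunePartitionsSketch(parts, v => v("year") == 2015)
// selected.map(_.path) == Seq("/t/year=2015/month=1")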