
Commit

Bug fixes. Lets data sources customize OutputCommitter rather than OutputFormat
liancheng committed May 12, 2015
1 parent 54c3d7b commit 5f423d3
Showing 5 changed files with 79 additions and 48 deletions.
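For orientation, a minimal sketch of the new extension point (not part of this commit; all names here are hypothetical). A data source built on FSBasedRelation can now supply its own committer by overriding outputCommitterClass; the class must extend FileOutputCommitter and expose a (Path, TaskAttemptContext) constructor, because BaseWriterContainer instantiates it reflectively:

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter

// Hypothetical committer: identical to FileOutputCommitter, but shows the
// (Path, TaskAttemptContext) constructor shape required by the reflective
// instantiation in BaseWriterContainer.newOutputCommitter.
class LoggingOutputCommitter(outputPath: Path, context: TaskAttemptContext)
  extends FileOutputCommitter(outputPath, context) {

  override def commitJob(jobContext: JobContext): Unit = {
    // Custom hook point; this sketch simply delegates to the default behavior.
    super.commitJob(jobContext)
  }
}

// A relation opts in by overriding the new hook (other FSBasedRelation members elided):
//   override def outputCommitterClass: Class[_ <: FileOutputCommitter] =
//     classOf[LoggingOutputCommitter]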
@@ -19,21 +19,23 @@ package org.apache.spark.sql.sources

import org.apache.hadoop.fs.Path

import org.apache.spark.Logging
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructType, UTF8String, StringType}
import org.apache.spark.sql.{Row, Strategy, execution, sources}
import org.apache.spark.sql._

/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
private[sql] object DataSourceStrategy extends Strategy {
private[sql] object DataSourceStrategy extends Strategy with Logging {
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) =>
pruneFilterProjectRaw(
@@ -59,7 +61,14 @@ private[sql] object DataSourceStrategy extends Strategy {
// Scanning partitioned FSBasedRelation
case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FSBasedRelation))
if t.partitionSpec.partitionColumns.nonEmpty =>
val selectedPartition = prunePartitions(filters, t.partitionSpec).toArray
val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray

logInfo {
val total = t.partitionSpec.partitions.length
val selected = selectedPartitions.length
val percentPruned = (1 - selected.toDouble / total.toDouble) * 100
s"Selected $selected partitions out of $total, pruned $percentPruned% partitions."
}

// Only pushes down predicates that do not reference partition columns.
val pushedFilters = {
@@ -75,7 +84,7 @@ private[sql] object DataSourceStrategy extends Strategy {
projectList,
pushedFilters,
t.partitionSpec.partitionColumns,
selectedPartition) :: Nil
selectedPartitions) :: Nil

// Scanning non-partitioned FSBasedRelation
case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FSBasedRelation)) =>
@@ -98,6 +107,12 @@ private[sql] object DataSourceStrategy extends Strategy {
l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty =>
execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil

case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: FSBasedRelation), part, query, overwrite, false) if part.isEmpty =>
val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
execution.ExecutedCommand(
InsertIntoFSBasedRelation(t, query, Array.empty[String], mode)) :: Nil

case _ => Nil
}

@@ -109,7 +124,6 @@ private[sql] object DataSourceStrategy extends Strategy {
partitions: Array[Partition]) = {
val output = projections.map(_.toAttribute)
val relation = logicalRelation.relation.asInstanceOf[FSBasedRelation]
val dataSchema = relation.dataSchema

// Builds RDD[Row]s for each selected partition.
val perPartitionRows = partitions.map { case Partition(partitionValues, dir) =>
@@ -136,9 +150,11 @@ private[sql] object DataSourceStrategy extends Strategy {
projections,
filters,
(requiredColumns, filters) => {
// Only columns appear in actual data, which possibly include some partition column(s)
// Don't require any partition columns to save I/O. Note that here we are being
// optimistic and assuming partition column values stored in data files are always
// consistent with those encoded in partition directory paths.
relation.buildScan(
requiredColumns.filter(dataSchema.fieldNames.contains),
requiredColumns.filterNot(partitionColumns.fieldNames.contains),
filters,
dataFilePaths)
})
@@ -147,8 +163,10 @@ private[sql] object DataSourceStrategy extends Strategy {
mergePartitionValues(output, partitionValues, scan)
}

val unionedRows =
perPartitionRows.reduceOption(_ ++ _).getOrElse(relation.sqlContext.emptyResult)
val unionedRows = perPartitionRows.reduceOption(_ ++ _).getOrElse {
relation.sqlContext.emptyResult
}

createPhysicalRDD(logicalRelation.relation, output, unionedRows)
}

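A user-level sketch of what exercises the new InsertIntoTable case added above (hypothetical table names, not from this commit): an insert into a table backed by an FSBasedRelation source is now planned as ExecutedCommand(InsertIntoFSBasedRelation(...)) rather than being rejected by the analyzer.

import org.apache.spark.sql.SQLContext

// Assumes "my_fs_table" was registered from an FSBasedRelation-based data source
// and "source_table" is any queryable table; both names are made up for illustration.
def overwriteExample(sqlContext: SQLContext): Unit = {
  // Matches the new strategy case with an empty partition spec and overwrite = true,
  // so it runs InsertIntoFSBasedRelation with SaveMode.Overwrite.
  sqlContext.sql("INSERT OVERWRITE TABLE my_fs_table SELECT * FROM source_table")
}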
57 changes: 32 additions & 25 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -21,7 +21,6 @@ import java.util.Date

import scala.collection.mutable

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat}
@@ -71,11 +70,15 @@ private[sql] case class InsertIntoFSBasedRelation(

val hadoopConf = sqlContext.sparkContext.hadoopConfiguration
val outputPath = new Path(relation.paths.head)

val fs = outputPath.getFileSystem(hadoopConf)
val doInsertion = (mode, fs.exists(outputPath)) match {
val qualifiedOutputPath = fs.makeQualified(outputPath)

val doInsertion = (mode, fs.exists(qualifiedOutputPath)) match {
case (SaveMode.ErrorIfExists, true) =>
sys.error(s"path $outputPath already exists.")
sys.error(s"path $qualifiedOutputPath already exists.")
case (SaveMode.Overwrite, true) =>
fs.delete(qualifiedOutputPath, true)
true
case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
true
case (SaveMode.Ignore, exists) =>
@@ -86,9 +89,7 @@ private[sql] case class InsertIntoFSBasedRelation(
val job = Job.getInstance(hadoopConf)
job.setOutputKeyClass(classOf[Void])
job.setOutputValueClass(classOf[Row])
FileOutputFormat.setOutputPath(job, outputPath)

val jobConf: Configuration = ContextUtil.getConfiguration(job)
FileOutputFormat.setOutputPath(job, qualifiedOutputPath)

val df = sqlContext.createDataFrame(
DataFrame(sqlContext, query).queryExecution.toRdd,
@@ -110,8 +111,9 @@ private[sql] case class InsertIntoFSBasedRelation(
private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = {
try {
writerContainer.driverSideSetup()
df.sqlContext.sparkContext.runJob(df.rdd, writeRows _)
df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
writerContainer.commitJob()
relation.refresh()
} catch { case cause: Throwable =>
writerContainer.abortJob()
throw new SparkException("Job aborted.", cause)
@@ -166,14 +168,15 @@ private[sql] case class InsertIntoFSBasedRelation(
val partitionDF = df.select(partitionColumns.head, partitionColumns.tail: _*)
val dataDF = df.select(dataCols.head, dataCols.tail: _*)

(partitionDF.rdd, dataDF.rdd)
partitionDF.queryExecution.executedPlan.execute() ->
dataDF.queryExecution.executedPlan.execute()
}

try {
writerContainer.driverSideSetup()
sqlContext.sparkContext.runJob(partitionRDD.zip(dataRDD), writeRows _)
writerContainer.commitJob()
relation.refreshPartitions()
relation.refresh()
} catch { case cause: Throwable =>
logError("Aborting job.", cause)
writerContainer.abortJob()
@@ -217,7 +220,7 @@ private[sql] abstract class BaseWriterContainer(
@transient private var taskAttemptId: TaskAttemptID = _
@transient protected var taskAttemptContext: TaskAttemptContext = _

protected val outputPath = {
protected val outputPath: String = {
assert(
relation.paths.length == 1,
s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}")
@@ -226,37 +229,41 @@ private[sql] abstract class BaseWriterContainer(

protected val dataSchema = relation.dataSchema

protected val outputFormatClass: Class[_ <: OutputFormat[Void, Row]] = relation.outputFormatClass
protected val outputCommitterClass: Class[_ <: FileOutputCommitter] =
relation.outputCommitterClass

protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass

def driverSideSetup(): Unit = {
setupIDs(0, 0, 0)
setupConf()
taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
val outputFormat = relation.outputFormatClass.newInstance()
outputCommitter = outputFormat.getOutputCommitter(taskAttemptContext) match {
case c: FileOutputCommitter => c
case _ => sys.error(
s"Output committer must be ${classOf[FileOutputCommitter].getName} or its subclasses")
}
outputCommitter = newOutputCommitter(outputCommitterClass, outputPath, taskAttemptContext)
outputCommitter.setupJob(jobContext)
}

def executorSideSetup(taskContext: TaskContext): Unit = {
setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber())
setupConf()
taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
val outputFormat = outputFormatClass.newInstance()
outputCommitter = outputFormat.getOutputCommitter(taskAttemptContext) match {
case c: FileOutputCommitter => c
case _ => sys.error(
s"Output committer must be ${classOf[FileOutputCommitter].getName} or its subclasses")
}
outputCommitter = newOutputCommitter(outputCommitterClass, outputPath, taskAttemptContext)
outputCommitter.setupTask(taskAttemptContext)
initWriters()
}

private def newOutputCommitter(
clazz: Class[_ <: FileOutputCommitter],
path: String,
context: TaskAttemptContext): FileOutputCommitter = {
val ctor = clazz.getConstructor(classOf[Path], classOf[TaskAttemptContext])
ctor.setAccessible(true)

val hadoopPath = new Path(path)
val fs = hadoopPath.getFileSystem(serializableConf.value)
val qualified = fs.makeQualified(hadoopPath)
ctor.newInstance(qualified, context)
}

private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = {
this.jobId = SparkHadoopWriter.createJobID(new Date, jobId)
this.taskId = new TaskID(this.jobId, true, splitId)
@@ -305,7 +312,7 @@ private[sql] class DefaultWriterContainer(
@transient private var writer: OutputWriter = _

override protected def initWriters(): Unit = {
writer = relation.outputWriterClass.newInstance()
writer = outputWriterClass.newInstance()
writer.init(outputCommitter.getWorkPath.toString, dataSchema, taskAttemptContext)
}

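As a quick restatement of the (mode, exists) match in InsertIntoFSBasedRelation above, a standalone helper capturing the same decision (hypothetical name; in the real code the Overwrite branch also deletes the existing path first, and the Ignore branch is assumed to skip the write when the path exists):

import org.apache.spark.sql.SaveMode

// Hypothetical restatement of the save-mode handling: returns whether the write proceeds.
def shouldInsert(mode: SaveMode, pathExists: Boolean): Boolean = (mode, pathExists) match {
  case (SaveMode.ErrorIfExists, true) => sys.error("path already exists")
  case (SaveMode.Ignore, exists)      => !exists
  case _                              => true  // Append always; Overwrite after deleting; ErrorIfExists when the path is absent
}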
@@ -20,7 +20,7 @@ package org.apache.spark.sql.sources
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.deploy.SparkHadoopUtil
@@ -325,13 +325,17 @@ abstract class FSBasedRelation private[sql](

private[sql] def partitionSpec: PartitionSpec = _partitionSpec

private[sql] def refreshPartitions(): Unit = {
private[sql] def refresh(): Unit = {
refreshPartitions()
}

private def refreshPartitions(): Unit = {
_partitionSpec = maybePartitionSpec.getOrElse {
val basePaths = paths.map(new Path(_))
val leafDirs = basePaths.flatMap { path =>
val fs = path.getFileSystem(hadoopConf)
if (fs.exists(path)) {
SparkHadoopUtil.get.listLeafDirStatuses(fs, path)
SparkHadoopUtil.get.listLeafDirStatuses(fs, fs.makeQualified(path))
} else {
Seq.empty[FileStatus]
}
@@ -349,7 +353,7 @@
* Schema of this relation. It consists of [[dataSchema]] and all partition columns not appearing
* in [[dataSchema]].
*/
override val schema: StructType = {
override lazy val schema: StructType = {
val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet
StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column =>
dataSchemaColumnNames.contains(column.name.toLowerCase)
@@ -411,7 +415,11 @@ abstract class FSBasedRelation private[sql](
buildScan(requiredColumns, inputPaths)
}

def outputFormatClass: Class[_ <: FileOutputFormat[Void, Row]]
/**
* The output committer class to use. Defaults to
* [[org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter]].
*/
def outputCommitterClass: Class[_ <: FileOutputCommitter] = classOf[FileOutputCommitter]

/**
* This method is responsible for producing a new [[OutputWriter]] for each newly opened output
10 changes: 7 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
@@ -101,9 +101,13 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
}
}

case i @ logical.InsertIntoTable(
l: LogicalRelation, partition, query, overwrite, ifNotExists)
if !l.isInstanceOf[InsertableRelation] =>
case logical.InsertIntoTable(LogicalRelation(_: InsertableRelation), _, _, _, _) =>
// OK

case logical.InsertIntoTable(LogicalRelation(_: FSBasedRelation), _, _, _, _) =>
// OK

case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) =>
// The relation in l is neither an InsertableRelation nor an FSBasedRelation.
failAnalysis(s"$l does not allow insertion.")

@@ -22,7 +22,6 @@ import scala.collection.mutable
import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.scalatest.BeforeAndAfter

import org.apache.spark.rdd.RDD
@@ -108,11 +107,6 @@ class SimpleFSBasedRelation
}

override def outputWriterClass: Class[_ <: OutputWriter] = classOf[SimpleOutputWriter]

override def outputFormatClass: Class[_ <: FileOutputFormat[Void, Row]] = {
// This is just a mock, not used within this test suite.
classOf[TextOutputFormat[Void, Row]]
}
}

object TestResult {
