
Commit 0bc6ad1

Resorts to new Hadoop API, and now FSBasedRelation can customize output format class
liancheng committed May 12, 2015
1 parent f320766 commit 0bc6ad1
Showing 3 changed files with 59 additions and 44 deletions.
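
The first file below rewrites the write path on top of Hadoop's newer org.apache.hadoop.mapreduce API: a Job replaces the old mapred JobConf, and the OutputCommitter is obtained from an OutputFormat instance and a TaskAttemptContext rather than from JobConf.getOutputCommitter. The following is a minimal stand-alone sketch of that setup pattern, assuming Hadoop 2.x; the output path, value type, and object name are placeholders for illustration, not values taken from this commit.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{Job, TaskAttemptID}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

object NewHadoopApiSketch {
  def main(args: Array[String]): Unit = {
    // Driver-side setup: configure output classes and the output path on a new-API Job.
    val job = Job.getInstance(new Configuration())
    job.setOutputKeyClass(classOf[Void])
    job.setOutputValueClass(classOf[String])          // the commit uses Row here
    FileOutputFormat.setOutputPath(job, new Path("/tmp/example-output"))

    // New API: the committer comes from an OutputFormat instance evaluated against a
    // TaskAttemptContext, not from JobConf.getOutputCommitter as in the old mapred API.
    val attemptContext =
      new TaskAttemptContextImpl(job.getConfiguration, new TaskAttemptID())
    val committer = new TextOutputFormat[Void, String]().getOutputCommitter(attemptContext)

    committer.setupJob(job)   // Job implements JobContext under Hadoop 2.x
    committer.commitJob(job)
  }
}

In the actual change, the SparkHadoopMapReduceUtil helpers (newJobContext, newTaskAttemptContext) paper over the Hadoop 1.x/2.x differences that TaskAttemptContextImpl exposes directly here.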
86 changes: 47 additions & 39 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -14,23 +14,27 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.sources

import java.util
import java.util.Date

import scala.collection.mutable

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred._
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.util.Shell
import parquet.hadoop.util.ContextUtil

import org.apache.spark._
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

private[sql] case class InsertIntoDataSource(
logicalRelation: LogicalRelation,
@@ -79,21 +83,23 @@ private[sql] case class InsertIntoFSBasedRelation(
}

if (doInsertion) {
val jobConf = new JobConf(hadoopConf)
jobConf.setOutputKeyClass(classOf[Void])
jobConf.setOutputValueClass(classOf[Row])
FileOutputFormat.setOutputPath(jobConf, outputPath)
val job = Job.getInstance(hadoopConf)
job.setOutputKeyClass(classOf[Void])
job.setOutputValueClass(classOf[Row])
FileOutputFormat.setOutputPath(job, outputPath)

val jobConf: Configuration = ContextUtil.getConfiguration(job)

val df = sqlContext.createDataFrame(
DataFrame(sqlContext, query).queryExecution.toRdd,
relation.schema,
needsConversion = false)

if (partitionColumns.isEmpty) {
insert(new DefaultWriterContainer(relation, jobConf), df)
insert(new DefaultWriterContainer(relation, job), df)
} else {
val writerContainer = new DynamicPartitionWriterContainer(
relation, jobConf, partitionColumns, "__HIVE_DEFAULT_PARTITION__")
relation, job, partitionColumns, "__HIVE_DEFAULT_PARTITION__")
insertWithDynamicPartitions(writerContainer, df, partitionColumns)
}
}
@@ -169,6 +175,7 @@ private[sql] case class InsertIntoFSBasedRelation(
writerContainer.commitJob()
relation.refreshPartitions()
} catch { case cause: Throwable =>
logError("Aborting job.", cause)
writerContainer.abortJob()
throw new SparkException("Job aborted.", cause)
}
@@ -193,24 +200,22 @@ private[sql] case class InsertIntoFSBasedRelation(

private[sql] abstract class BaseWriterContainer(
@transient val relation: FSBasedRelation,
@transient jobConf: JobConf)
extends SparkHadoopMapRedUtil
@transient job: Job)
extends SparkHadoopMapReduceUtil
with Logging
with Serializable {

protected val serializableJobConf = new SerializableWritable(jobConf)
protected val serializableConf = new SerializableWritable(ContextUtil.getConfiguration(job))

// This is only used on driver side.
@transient private var jobContext: JobContext = _

// This is only used on executor side.
@transient private var taskAttemptContext: TaskAttemptContext = _
@transient private var jobContext: JobContext = job

// The following fields are initialized and used on both driver and executor side.
@transient private var outputCommitter: OutputCommitter = _
@transient private var jobId: JobID = _
@transient private var taskId: TaskID = _
@transient private var taskAttemptId: TaskAttemptID = _
@transient private var taskAttemptContext: TaskAttemptContext = _

protected val outputPath = {
assert(
@@ -221,22 +226,25 @@ private[sql] abstract class BaseWriterContainer(

protected val dataSchema = relation.dataSchema

protected val outputFormatClass: Class[_ <: OutputFormat[Void, Row]] = relation.outputFormatClass

protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass

def driverSideSetup(): Unit = {
setupIDs(0, 0, 0)
relation.prepareForWrite(serializableJobConf.value)
setupJobConf()
jobContext = newJobContext(jobConf, jobId)
outputCommitter = jobConf.getOutputCommitter
setupConf()
taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
val outputFormat = relation.outputFormatClass.newInstance()
outputCommitter = outputFormat.getOutputCommitter(taskAttemptContext)
outputCommitter.setupJob(jobContext)
}

def executorSideSetup(taskContext: TaskContext): Unit = {
setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber())
setupJobConf()
taskAttemptContext = newTaskAttemptContext(serializableJobConf.value, taskAttemptId)
outputCommitter = serializableJobConf.value.getOutputCommitter
setupConf()
taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
val outputFormat = outputFormatClass.newInstance()
outputCommitter = outputFormat.getOutputCommitter(taskAttemptContext)
outputCommitter.setupTask(taskAttemptContext)
initWriters()
}
@@ -247,20 +255,20 @@ private[sql] abstract class BaseWriterContainer(
this.taskAttemptId = new TaskAttemptID(taskId, attemptId)
}

private def setupJobConf(): Unit = {
serializableJobConf.value.set("mapred.job.id", jobId.toString)
serializableJobConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString)
serializableJobConf.value.set("mapred.task.id", taskAttemptId.toString)
serializableJobConf.value.setBoolean("mapred.task.is.map", true)
serializableJobConf.value.setInt("mapred.task.partition", 0)
private def setupConf(): Unit = {
serializableConf.value.set("mapred.job.id", jobId.toString)
serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString)
serializableConf.value.set("mapred.task.id", taskAttemptId.toString)
serializableConf.value.setBoolean("mapred.task.is.map", true)
serializableConf.value.setInt("mapred.task.partition", 0)
}

// Called on executor side when writing rows
def outputWriterForRow(row: Row): OutputWriter

protected def initWriters(): Unit = {
val writer = outputWriterClass.newInstance()
writer.init(outputPath, dataSchema, serializableJobConf.value)
writer.init(outputPath, dataSchema, serializableConf.value)
mutable.Map(outputPath -> writer)
}

@@ -280,21 +288,21 @@ private[sql] abstract class BaseWriterContainer(
}

def abortJob(): Unit = {
outputCommitter.abortJob(jobContext, JobStatus.FAILED)
outputCommitter.abortJob(jobContext, JobStatus.State.FAILED)
logError(s"Job $jobId aborted.")
}
}

private[sql] class DefaultWriterContainer(
@transient relation: FSBasedRelation,
@transient conf: JobConf)
extends BaseWriterContainer(relation, conf) {
@transient job: Job)
extends BaseWriterContainer(relation, job) {

@transient private var writer: OutputWriter = _

override protected def initWriters(): Unit = {
writer = relation.outputWriterClass.newInstance()
writer.init(outputPath, dataSchema, serializableJobConf.value)
writer.init(outputPath, dataSchema, serializableConf.value)
}

override def outputWriterForRow(row: Row): OutputWriter = writer
@@ -312,10 +320,10 @@ private[sql] class DefaultWriterContainer(

private[sql] class DynamicPartitionWriterContainer(
@transient relation: FSBasedRelation,
@transient conf: JobConf,
@transient job: Job,
partitionColumns: Array[String],
defaultPartitionName: String)
extends BaseWriterContainer(relation, conf) {
extends BaseWriterContainer(relation, job) {

// All output writers are created on executor side.
@transient protected var outputWriters: mutable.Map[String, OutputWriter] = _
@@ -338,7 +346,7 @@ private[sql] class DynamicPartitionWriterContainer(
outputWriters.getOrElseUpdate(partitionPath, {
val path = new Path(outputPath, partitionPath.stripPrefix(Path.SEPARATOR))
val writer = outputWriterClass.newInstance()
writer.init(path.toString, dataSchema, serializableJobConf.value)
writer.init(path.toString, dataSchema, serializableConf.value)
writer
})
}
@@ -356,7 +364,7 @@ private[sql] class DynamicPartitionWriterContainer(

private[sql] object DynamicPartitionWriterContainer {
val charToEscape = {
val bitSet = new util.BitSet(128)
val bitSet = new java.util.BitSet(128)

/**
* ASCII 01-1F are HTTP control characters that need to be escaped.
@@ -379,7 +387,7 @@ private[sql] object DynamicPartitionWriterContainer {
}

def needsEscaping(c: Char): Boolean = {
c >= 0 && c < charToEscape.size() && charToEscape.get(c);
c >= 0 && c < charToEscape.size() && charToEscape.get(c)
}

def escapePathName(path: String): String = {
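
DynamicPartitionWriterContainer above escapes characters that cannot safely appear in dynamic partition paths using a 128-entry BitSet lookup. A small self-contained sketch of that check follows; the object name is invented and only a handful of representative characters are marked, whereas the commit's full list follows Hive's conventions.

import java.util.BitSet

object PartitionEscapingSketch {
  // ASCII 0x01-0x1F are control characters; a few other characters are unsafe in
  // partition directory names. Only a representative subset is marked here.
  val charToEscape: BitSet = {
    val bitSet = new BitSet(128)
    (1 to 31).foreach(i => bitSet.set(i))
    Seq('"', '#', '%', '*', '/', ':', '=', '?', '\\').foreach(c => bitSet.set(c))
    bitSet
  }

  def needsEscaping(c: Char): Boolean =
    c >= 0 && c < charToEscape.size() && charToEscape.get(c)

  def main(args: Array[String]): Unit = {
    println(needsEscaping('/'))   // true: '/' would otherwise introduce a new path segment
    println(needsEscaping('a'))   // false: plain letters pass through unescaped
  }
}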
@@ -19,6 +19,8 @@ package org.apache.spark.sql.sources

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce.{OutputFormat, OutputCommitter}
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.deploy.SparkHadoopUtil
@@ -363,7 +365,7 @@ abstract class FSBasedRelation private[sql](
/**
* For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within
* this relation. For partitioned relations, this method is called for each selected partition,
* and builds an `RDD[Row]` containg all rows within that single partition.
* and builds an `RDD[Row]` containing all rows within that single partition.
*
* @param inputPaths For a non-partitioned relation, it contains paths of all data files in the
* relation. For a partitioned relation, it contains paths of all data files in a single
@@ -377,7 +379,7 @@ abstract class FSBasedRelation private[sql](
/**
* For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within
* this relation. For partitioned relations, this method is called for each selected partition,
* and builds an `RDD[Row]` containg all rows within that single partition.
* and builds an `RDD[Row]` containing all rows within that single partition.
*
* @param requiredColumns Required columns.
* @param inputPaths For a non-partitioned relation, it contains paths of all data files in the
@@ -391,7 +393,7 @@ abstract class FSBasedRelation private[sql](
/**
* For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within
* this relation. For partitioned relations, this method is called for each selected partition,
* and builds an `RDD[Row]` containg all rows within that single partition.
* and builds an `RDD[Row]` containing all rows within that single partition.
*
* @param requiredColumns Required columns.
* @param filters Candidate filters to be pushed down. The actual filter should be the conjunction
@@ -409,7 +411,7 @@ abstract class FSBasedRelation private[sql](
buildScan(requiredColumns, inputPaths)
}

def prepareForWrite(conf: Configuration): Unit
def outputFormatClass: Class[_ <: OutputFormat[Void, Row]]

/**
* This method is responsible for producing a new [[OutputWriter]] for each newly opened output
@@ -22,6 +22,8 @@ import scala.collection.mutable
import com.google.common.base.Objects
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.OutputFormat
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.scalatest.BeforeAndAfter

import org.apache.spark.rdd.RDD
@@ -108,7 +110,10 @@ class SimpleFSBasedRelation

override def outputWriterClass: Class[_ <: OutputWriter] = classOf[SimpleOutputWriter]

override def prepareForWrite(conf: Configuration): Unit = ()
override def outputFormatClass: Class[_ <: OutputFormat[Void, Row]] = {
// This is just a mock, not used within this test suite.
classOf[TextOutputFormat[Void, Row]]
}
}

object TestResult {
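The interface change in the last two files replaces the prepareForWrite(conf) hook with an outputFormatClass method, so a concrete FSBasedRelation (such as the SimpleFSBasedRelation above) now only names the new-API OutputFormat backing its writes. Below is a stand-alone sketch of the reflective pattern the writer containers rely on; the object name and the use of String instead of Row are illustrative only.

import org.apache.hadoop.mapreduce.OutputFormat
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

object OutputFormatClassSketch {
  def main(args: Array[String]): Unit = {
    // Carrying a Class value (as FSBasedRelation.outputFormatClass now does) lets both the
    // driver and the executors instantiate the format reflectively, without requiring the
    // OutputFormat instance itself to be serializable.
    val formatClass: Class[_ <: OutputFormat[Void, String]] = classOf[TextOutputFormat[Void, String]]
    val format: OutputFormat[Void, String] = formatClass.newInstance()
    println(format.getClass.getName)   // prints the concrete format's class name
  }
}

Each writer container then asks such an instance for its OutputCommitter, as in driverSideSetup and executorSideSetup above.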
