Adds OutputWriterFactory
liancheng committed May 14, 2015
1 parent 047d40d commit 522c24e
Showing 4 changed files with 59 additions and 64 deletions.
@@ -54,14 +54,10 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider {
}

// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
private[sql] class ParquetOutputWriter extends OutputWriter {
private var recordWriter: RecordWriter[Void, Row] = _
private var taskAttemptContext: TaskAttemptContext = _

override def init(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): Unit = {
private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext)
extends OutputWriter {

private val recordWriter: RecordWriter[Void, Row] = {
val conf = context.getConfiguration
val outputFormat = {
// When appending new Parquet files to an existing Parquet file directory, to avoid
@@ -86,9 +82,8 @@ private[sql] class ParquetOutputWriter extends OutputWriter {
case name if name.startsWith("_") => 0
case name if name.startsWith(".") => 0
case name => sys.error(
s"""Trying to write Parquet files to directory $outputPath,
|but found items with illegal name "$name"
""".stripMargin.replace('\n', ' ').trim)
s"Trying to write Parquet files to directory $outputPath, " +
s"but found items with illegal name '$name'.")
}.reduceOption(_ max _).getOrElse(0)
} else {
0
@@ -111,13 +106,12 @@ private[sql] class ParquetOutputWriter extends OutputWriter {
}
}

recordWriter = outputFormat.getRecordWriter(context)
taskAttemptContext = context
outputFormat.getRecordWriter(context)
}

override def write(row: Row): Unit = recordWriter.write(null, row)

override def close(): Unit = recordWriter.close(taskAttemptContext)
override def close(): Unit = recordWriter.close(context)
}

private[sql] class ParquetRelation2(
@@ -175,8 +169,6 @@ private[sql] class ParquetRelation2(
}
}

override def outputWriterClass: Class[_ <: OutputWriter] = classOf[ParquetOutputWriter]

override def dataSchema: StructType = metadataCache.dataSchema

override private[sql] def refresh(): Unit = {
@@ -189,7 +181,7 @@

override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum

override def prepareJobForWrite(job: Job): Unit = {
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
val conf = ContextUtil.getConfiguration(job)

val committerClass =
@@ -224,6 +216,13 @@ private[sql] class ParquetRelation2(
.getOrElse(
sqlContext.conf.parquetCompressionCodec.toUpperCase,
CompressionCodecName.UNCOMPRESSED).name())

new OutputWriterFactory {
override def newInstance(
path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = {
new ParquetOutputWriter(path, context)
}
}
}

override def buildScan(
@@ -261,15 +261,15 @@ private[sql] abstract class BaseWriterContainer(

protected val dataSchema = relation.dataSchema

protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass
protected var outputWriterFactory: OutputWriterFactory = _

private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _

def driverSideSetup(): Unit = {
setupIDs(0, 0, 0)
setupConf()
taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
relation.prepareJobForWrite(job)
outputWriterFactory = relation.prepareJobForWrite(job)
outputFormatClass = job.getOutputFormatClass
outputCommitter = newOutputCommitter(taskAttemptContext)
outputCommitter.setupJob(jobContext)
@@ -353,9 +353,8 @@ private[sql] class DefaultWriterContainer(
@transient private var writer: OutputWriter = _

override protected def initWriters(): Unit = {
writer = outputWriterClass.newInstance()
taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath)
writer.init(getWorkPath, dataSchema, taskAttemptContext)
writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext)
}

override def outputWriterForRow(row: Row): OutputWriter = writer
@@ -398,12 +397,10 @@ private[sql] class DynamicPartitionWriterContainer(

outputWriters.getOrElseUpdate(partitionPath, {
val path = new Path(getWorkPath, partitionPath)
val writer = outputWriterClass.newInstance()
taskAttemptContext.getConfiguration.set(
"spark.sql.sources.output.path",
new Path(outputPath, partitionPath).toString)
writer.init(path.toString, dataSchema, taskAttemptContext)
writer
outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext)
})
}

@@ -280,33 +280,42 @@ trait CatalystScan {

/**
* ::Experimental::
* [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the
* underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor.
* An [[OutputWriter]] instance is created and initialized when a new output file is opened on
* executor side. This instance is used to persist rows to this single output file.
* A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver
* side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized
* to executor side to create actual [[OutputWriter]]s on the fly.
*
* @since 1.4.0
*/
@Experimental
abstract class OutputWriter {
trait OutputWriterFactory extends Serializable {
/**
* Initializes this [[OutputWriter]] before any rows are persisted.
* When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side
* to instantiate new [[OutputWriter]]s.
*
* @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that
* this may not point to the final output file. For example, `FileOutputFormat` writes to
* temporary directories and then merge written files back to the final destination. In
* this case, `path` points to a temporary output file under the temporary directory.
* @param dataSchema Schema of the rows to be written. Partition columns are not included in the
* schema if the corresponding relation is partitioned.
* schema if the relation being written is partitioned.
* @param context The Hadoop MapReduce task context.
*
* @since 1.4.0
*/
def init(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): Unit = ()
def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter
}

/**
* ::Experimental::
* [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the
* underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor.
* An [[OutputWriter]] instance is created and initialized when a new output file is opened on
* executor side. This instance is used to persist rows to this single output file.
*
* @since 1.4.0
*/
@Experimental
abstract class OutputWriter {
/**
* Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
* tables, dynamic partition columns are not included in rows to be written.
@@ -333,8 +342,8 @@ abstract class OutputWriter {
* filter using selected predicates before producing an RDD containing all matching tuples as
* [[Row]] objects. In addition, when reading from Hive style partitioned tables stored in file
* systems, it's able to discover partitioning information from the paths of input directories, and
* perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]] must
* override one of the three `buildScan` methods to implement the read path.
* perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]]
* must override one of the three `buildScan` methods to implement the read path.
*
* For the write path, it provides the ability to write to both non-partitioned and partitioned
* tables. Directory layout of the partitioned tables is compatible with Hive.
@@ -520,22 +529,14 @@ abstract class HadoopFsRelation private[sql](
}

/**
* Client side preparation for data writing can be put here. For example, user defined output
* committer can be configured here.
* Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
* be put here. For example, user defined output committer can be configured here.
*
* Note that the only side effect expected here is mutating `job` via its setters. Especially,
* Spark SQL caches [[BaseRelation]] instances for performance, mutating relation internal states
* may cause unexpected behaviors.
*
* @since 1.4.0
*/
def prepareJobForWrite(job: Job): Unit = ()

/**
* This method is responsible for producing a new [[OutputWriter]] for each newly opened output
* file on the executor side.
*
* @since 1.4.0
*/
def outputWriterClass: Class[_ <: OutputWriter]
def prepareJobForWrite(job: Job): OutputWriterFactory
}
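
Taken together, the new contract for a data source built on HadoopFsRelation is: prepareJobForWrite runs once on the driver and returns an OutputWriterFactory; that factory is serialized into every write task; each task then calls newInstance once per output file (or once per dynamic partition directory) and feeds rows to the resulting OutputWriter until close() is called. Below is a minimal, hypothetical sketch of that contract against the API defined above. It is not code from this commit; the class names (CsvOutputWriter, CsvOutputWriterFactory) and the per-task file naming are made up for illustration. The SimpleTextRelation test changes further down follow the same pattern.

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

// Hypothetical writer: treats `path` as the task's work directory (as the containers in
// the previous file pass it) and writes one comma-separated line per row to a per-task file.
class CsvOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter {
  private val outputStream = {
    val taskId = context.getTaskAttemptID.getTaskID.getId
    val file = new Path(path, s"part-$taskId.csv")
    file.getFileSystem(context.getConfiguration).create(file)
  }

  override def write(row: Row): Unit = {
    val line = row.toSeq.map(_.toString).mkString(",") + "\n"
    outputStream.write(line.getBytes("UTF-8"))
  }

  override def close(): Unit = outputStream.close()
}

// Hypothetical factory: created on the driver (typically returned from prepareJobForWrite)
// and shipped to executors, which is why OutputWriterFactory extends Serializable.
class CsvOutputWriterFactory extends OutputWriterFactory {
  override def newInstance(
      path: String,
      dataSchema: StructType,
      context: TaskAttemptContext): OutputWriter = new CsvOutputWriter(path, context)
}

A HadoopFsRelation subclass would return new CsvOutputWriterFactory from its prepareJobForWrite(job) implementation, and can capture driver-side, per-job settings (such as the compression codec configured in ParquetRelation2 above) as constructor arguments of the factory, which the old zero-argument Class[_ <: OutputWriter] contract could not express directly.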
@@ -24,7 +24,7 @@ import com.google.common.base.Objects
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
@@ -59,24 +59,16 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW
}
}

class SimpleTextOutputWriter extends OutputWriter {
private var recordWriter: RecordWriter[NullWritable, Text] = _
private var taskAttemptContext: TaskAttemptContext = _

override def init(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): Unit = {
recordWriter = new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)
taskAttemptContext = context
}
class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter {
private val recordWriter: RecordWriter[NullWritable, Text] =
new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)

override def write(row: Row): Unit = {
val serialized = row.toSeq.map(_.toString).mkString(",")
recordWriter.write(null, new Text(serialized))
}

override def close(): Unit = recordWriter.close(taskAttemptContext)
override def close(): Unit = recordWriter.close(context)
}

/**
@@ -110,9 +102,6 @@ class SimpleTextRelation(
override def hashCode(): Int =
Objects.hashCode(paths, maybeDataSchema, dataSchema)

override def outputWriterClass: Class[_ <: OutputWriter] =
classOf[SimpleTextOutputWriter]

override def buildScan(inputPaths: Array[String]): RDD[Row] = {
val fields = dataSchema.map(_.dataType)

@@ -122,4 +111,13 @@
}: _*)
}
}

override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new SimpleTextOutputWriter(path, context)
}
}
}
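
None of this changes how a write is triggered: a save through the data sources API still resolves the relation on the driver, calls prepareJobForWrite there, and hands the returned OutputWriterFactory to the executor-side writer containers shown earlier. A hypothetical driver-side example, assuming a fictional "com.example.csv" provider built on HadoopFsRelation and the DataFrame.save overload that takes a path, a source name, and a SaveMode:

import org.apache.spark.sql.{DataFrame, SaveMode}

// Hypothetical usage: persisting a DataFrame through a HadoopFsRelation-based source.
// "com.example.csv" is a made-up provider package name for illustration.
def saveAsCsv(df: DataFrame, path: String): Unit =
  df.save(path, "com.example.csv", SaveMode.Overwrite)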
