diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 208e2e890b013..f9f5eee30e2df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -54,14 +54,10 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider {
 }
 
 // NOTE: This class is instantiated and used on executor side only, no need to be serializable.
-private[sql] class ParquetOutputWriter extends OutputWriter {
-  private var recordWriter: RecordWriter[Void, Row] = _
-  private var taskAttemptContext: TaskAttemptContext = _
-
-  override def init(
-      path: String,
-      dataSchema: StructType,
-      context: TaskAttemptContext): Unit = {
+private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext)
+  extends OutputWriter {
+
+  private val recordWriter: RecordWriter[Void, Row] = {
     val conf = context.getConfiguration
     val outputFormat = {
       // When appending new Parquet files to an existing Parquet file directory, to avoid
@@ -86,9 +82,8 @@ private[sql] class ParquetOutputWriter extends OutputWriter {
             case name if name.startsWith("_") => 0
             case name if name.startsWith(".") => 0
             case name => sys.error(
-              s"""Trying to write Parquet files to directory $outputPath,
-                 |but found items with illegal name "$name"
-               """.stripMargin.replace('\n', ' ').trim)
+              s"Trying to write Parquet files to directory $outputPath, " +
+                s"but found items with illegal name '$name'.")
           }.reduceOption(_ max _).getOrElse(0)
         } else {
           0
@@ -111,13 +106,12 @@ private[sql] class ParquetOutputWriter extends OutputWriter {
       }
     }
 
-    recordWriter = outputFormat.getRecordWriter(context)
-    taskAttemptContext = context
+    outputFormat.getRecordWriter(context)
   }
 
   override def write(row: Row): Unit = recordWriter.write(null, row)
 
-  override def close(): Unit = recordWriter.close(taskAttemptContext)
+  override def close(): Unit = recordWriter.close(context)
 }
 
 private[sql] class ParquetRelation2(
@@ -175,8 +169,6 @@ private[sql] class ParquetRelation2(
     }
   }
 
-  override def outputWriterClass: Class[_ <: OutputWriter] = classOf[ParquetOutputWriter]
-
   override def dataSchema: StructType = metadataCache.dataSchema
 
   override private[sql] def refresh(): Unit = {
@@ -189,7 +181,7 @@ private[sql] class ParquetRelation2(
 
   override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum
 
-  override def prepareJobForWrite(job: Job): Unit = {
+  override def prepareJobForWrite(job: Job): OutputWriterFactory = {
     val conf = ContextUtil.getConfiguration(job)
 
     val committerClass =
@@ -224,6 +216,13 @@ private[sql] class ParquetRelation2(
         .getOrElse(
           sqlContext.conf.parquetCompressionCodec.toUpperCase,
           CompressionCodecName.UNCOMPRESSED).name())
+
+    new OutputWriterFactory {
+      override def newInstance(
+          path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = {
+        new ParquetOutputWriter(path, context)
+      }
+    }
   }
 
   override def buildScan(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 5b104eb7c4a21..a09bb08de736a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -261,7 +261,7 @@ private[sql] abstract class BaseWriterContainer(
 
   protected val dataSchema = relation.dataSchema
 
-  protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass
+  protected var outputWriterFactory: OutputWriterFactory = _
 
   private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _
 
@@ -269,7 +269,7 @@ private[sql] abstract class BaseWriterContainer(
     setupIDs(0, 0, 0)
     setupConf()
     taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
-    relation.prepareJobForWrite(job)
+    outputWriterFactory = relation.prepareJobForWrite(job)
     outputFormatClass = job.getOutputFormatClass
     outputCommitter = newOutputCommitter(taskAttemptContext)
     outputCommitter.setupJob(jobContext)
@@ -353,9 +353,8 @@ private[sql] class DefaultWriterContainer(
   @transient private var writer: OutputWriter = _
 
   override protected def initWriters(): Unit = {
-    writer = outputWriterClass.newInstance()
     taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath)
-    writer.init(getWorkPath, dataSchema, taskAttemptContext)
+    writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext)
   }
 
   override def outputWriterForRow(row: Row): OutputWriter = writer
@@ -398,12 +397,10 @@ private[sql] class DynamicPartitionWriterContainer(
 
     outputWriters.getOrElseUpdate(partitionPath, {
       val path = new Path(getWorkPath, partitionPath)
-      val writer = outputWriterClass.newInstance()
       taskAttemptContext.getConfiguration.set(
         "spark.sql.sources.output.path",
         new Path(outputPath, partitionPath).toString)
-      writer.init(path.toString, dataSchema, taskAttemptContext)
-      writer
+      outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext)
     })
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index f672fd3588c69..770afe4899e2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -280,33 +280,42 @@ trait CatalystScan {
 
 /**
  * ::Experimental::
- * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the
- * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor.
- * An [[OutputWriter]] instance is created and initialized when a new output file is opened on
- * executor side. This instance is used to persist rows to this single output file.
+ * A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver
+ * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized
+ * to executor side to create actual [[OutputWriter]]s on the fly.
  *
  * @since 1.4.0
  */
 @Experimental
-abstract class OutputWriter {
+trait OutputWriterFactory extends Serializable {
   /**
-   * Initializes this [[OutputWriter]] before any rows are persisted.
+   * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side
+   * to instantiate new [[OutputWriter]]s.
    *
    * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that
    *        this may not point to the final output file. For example, `FileOutputFormat` writes to
    *        temporary directories and then merge written files back to the final destination. In
    *        this case, `path` points to a temporary output file under the temporary directory.
    * @param dataSchema Schema of the rows to be written. Partition columns are not included in the
-   *        schema if the corresponding relation is partitioned.
+   *        schema if the relation being written is partitioned.
    * @param context The Hadoop MapReduce task context.
    *
    * @since 1.4.0
    */
-  def init(
-      path: String,
-      dataSchema: StructType,
-      context: TaskAttemptContext): Unit = ()
-
+  def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter
+}
+
+/**
+ * ::Experimental::
+ * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the
+ * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor.
+ * An [[OutputWriter]] instance is created and initialized when a new output file is opened on
+ * executor side. This instance is used to persist rows to this single output file.
+ *
+ * @since 1.4.0
+ */
+@Experimental
+abstract class OutputWriter {
   /**
    * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
    * tables, dynamic partition columns are not included in rows to be written.
@@ -333,8 +342,8 @@ abstract class OutputWriter {
  * filter using selected predicates before producing an RDD containing all matching tuples as
  * [[Row]] objects. In addition, when reading from Hive style partitioned tables stored in file
  * systems, it's able to discover partitioning information from the paths of input directories, and
- * perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]] must
- * override one of the three `buildScan` methods to implement the read path.
+ * perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]]
+ * must override one of the three `buildScan` methods to implement the read path.
  *
  * For the write path, it provides the ability to write to both non-partitioned and partitioned
  * tables. Directory layout of the partitioned tables is compatible with Hive.
@@ -520,8 +529,8 @@ abstract class HadoopFsRelation private[sql](
   }
 
   /**
-   * Client side preparation for data writing can be put here. For example, user defined output
-   * committer can be configured here.
+   * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
+   * be put here. For example, user defined output committer can be configured here.
    *
    * Note that the only side effect expected here is mutating `job` via its setters. Especially,
    * Spark SQL caches [[BaseRelation]] instances for performance, mutating relation internal states
@@ -529,13 +538,5 @@ abstract class HadoopFsRelation private[sql](
    *
    * @since 1.4.0
    */
-  def prepareJobForWrite(job: Job): Unit = ()
-
-  /**
-   * This method is responsible for producing a new [[OutputWriter]] for each newly opened output
-   * file on the executor side.
-   *
-   * @since 1.4.0
-   */
-  def outputWriterClass: Class[_ <: OutputWriter]
+  def prepareJobForWrite(job: Job): OutputWriterFactory
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index af7459d2ea809..f2f77bacbc100 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -24,7 +24,7 @@ import com.google.common.base.Objects
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
-import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
@@ -59,24 +59,16 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW
   }
 }
 
-class SimpleTextOutputWriter extends OutputWriter {
-  private var recordWriter: RecordWriter[NullWritable, Text] = _
-  private var taskAttemptContext: TaskAttemptContext = _
-
-  override def init(
-      path: String,
-      dataSchema: StructType,
-      context: TaskAttemptContext): Unit = {
-    recordWriter = new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)
-    taskAttemptContext = context
-  }
+class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter {
+  private val recordWriter: RecordWriter[NullWritable, Text] =
+    new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)
 
   override def write(row: Row): Unit = {
     val serialized = row.toSeq.map(_.toString).mkString(",")
     recordWriter.write(null, new Text(serialized))
   }
 
-  override def close(): Unit = recordWriter.close(taskAttemptContext)
+  override def close(): Unit = recordWriter.close(context)
 }
 
 /**
@@ -110,9 +102,6 @@ class SimpleTextRelation(
 
   override def hashCode(): Int = Objects.hashCode(paths, maybeDataSchema, dataSchema)
 
-  override def outputWriterClass: Class[_ <: OutputWriter] =
-    classOf[SimpleTextOutputWriter]
-
   override def buildScan(inputPaths: Array[String]): RDD[Row] = {
     val fields = dataSchema.map(_.dataType)
 
@@ -122,4 +111,13 @@ class SimpleTextRelation(
       }: _*)
     }
   }
+
+  override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory {
+    override def newInstance(
+        path: String,
+        dataSchema: StructType,
+        context: TaskAttemptContext): OutputWriter = {
+      new SimpleTextOutputWriter(path, context)
+    }
+  }
 }
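For reference, here is a minimal sketch of what a data source written against the new contract looks like, modeled on the `SimpleTextOutputWriter` / `AppendingTextOutputFormat` test fixture changed above. It is not part of this patch: `CommaSeparatedOutputWriter` and `CommaSeparatedExample` are illustrative names, and the sketch assumes the post-patch `OutputWriter` / `OutputWriterFactory` interfaces from `sql/core`'s `sources/interfaces.scala`.

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

// Illustrative writer: one instance per output file, created on an executor by the factory below.
class CommaSeparatedOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter {
  private val recordWriter: RecordWriter[NullWritable, Text] =
    new TextOutputFormat[NullWritable, Text]() {
      // Write to the exact path handed to newInstance() instead of the format's default work
      // file, mirroring what AppendingTextOutputFormat does in the test fixture above.
      override def getDefaultWorkFile(c: TaskAttemptContext, extension: String): Path =
        new Path(path)
    }.getRecordWriter(context)

  // Serialize each row as a comma-separated line, as SimpleTextOutputWriter does.
  override def write(row: Row): Unit =
    recordWriter.write(null, new Text(row.toSeq.map(_.toString).mkString(",")))

  override def close(): Unit = recordWriter.close(context)
}

object CommaSeparatedExample {
  // What a HadoopFsRelation subclass would return from prepareJobForWrite(job): the factory is
  // built on the driver, serialized to executors, and asked for one OutputWriter per output file.
  val outputWriterFactory: OutputWriterFactory = new OutputWriterFactory {
    override def newInstance(
        path: String,
        dataSchema: StructType,
        context: TaskAttemptContext): OutputWriter = new CommaSeparatedOutputWriter(path, context)
  }
}
```

Compared with the old contract (zero-argument constructor plus `init()`), passing `path` and `context` through the constructor lets `recordWriter` be an immutable `val`, and routing instantiation through a serializable factory keeps all job-level setup on the driver inside `prepareJobForWrite` while executors only call `newInstance`.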