diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 38ae68006e189..e25b818ceaf53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -229,16 +229,17 @@ private[sql] abstract class BaseWriterContainer(
 
   protected val dataSchema = relation.dataSchema
 
-  protected val outputCommitterClass: Class[_ <: FileOutputCommitter] =
-    relation.outputCommitterClass
-
   protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass
 
+  private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _
+
   def driverSideSetup(): Unit = {
     setupIDs(0, 0, 0)
     setupConf()
     taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
-    outputCommitter = newOutputCommitter(outputCommitterClass, outputPath, taskAttemptContext)
+    relation.prepareForWrite(job)
+    outputFormatClass = job.getOutputFormatClass
+    outputCommitter = newOutputCommitter(taskAttemptContext)
     outputCommitter.setupJob(jobContext)
   }
 
@@ -246,22 +247,17 @@ private[sql] abstract class BaseWriterContainer(
     setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber())
     setupConf()
     taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
-    outputCommitter = newOutputCommitter(outputCommitterClass, outputPath, taskAttemptContext)
+    outputCommitter = newOutputCommitter(taskAttemptContext)
     outputCommitter.setupTask(taskAttemptContext)
     initWriters()
   }
 
-  private def newOutputCommitter(
-      clazz: Class[_ <: FileOutputCommitter],
-      path: String,
-      context: TaskAttemptContext): FileOutputCommitter = {
-    val ctor = outputCommitterClass.getConstructor(classOf[Path], classOf[TaskAttemptContext])
-    ctor.setAccessible(true)
-
-    val hadoopPath = new Path(path)
-    val fs = hadoopPath.getFileSystem(serializableConf.value)
-    val qualified = fs.makeQualified(hadoopPath)
-    ctor.newInstance(qualified, context)
+  private def newOutputCommitter(context: TaskAttemptContext): FileOutputCommitter = {
+    outputFormatClass.newInstance().getOutputCommitter(context) match {
+      case f: FileOutputCommitter => f
+      case f => sys.error(
+        s"FileOutputCommitter or its subclass is expected, but got a ${f.getClass.getName}.")
+    }
   }
 
   private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index e46e9516bc203..e42384c4cee32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.sources
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
@@ -432,10 +432,14 @@ abstract class FSBasedRelation private[sql](
   }
 
   /**
-   * The output committer class to use. Default to
-   * [[org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter]].
+   * Client-side preparation for writing data can be put here. For example, a user-defined
+   * output committer can be configured here.
+   *
+   * Note that the only side effect expected here is mutating `job` via its setters. In
+   * particular, since Spark SQL caches [[BaseRelation]] instances for performance, mutating a
+   * relation's internal state may cause unexpected behavior.
    */
-  def outputCommitterClass: Class[_ <: FileOutputCommitter] = classOf[FileOutputCommitter]
+  def prepareForWrite(job: Job): Unit = ()
 
   /**
    * This method is responsible for producing a new [[OutputWriter]] for each newly opened output
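With this change, a data source hooks in a custom output committer by configuring the Hadoop Job inside prepareForWrite: BaseWriterContainer then instantiates whatever OutputFormat the job carries and asks it for the committer. A minimal sketch of what an implementor might write is shown below; the class names are made up for illustration, and the committer has to extend FileOutputCommitter because newOutputCommitter above rejects anything else.

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{Job, JobContext, OutputCommitter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat, TextOutputFormat}

// Hypothetical committer; it extends FileOutputCommitter, which the write path requires.
class AuditingOutputCommitter(outputPath: Path, context: TaskAttemptContext)
  extends FileOutputCommitter(outputPath, context) {

  override def commitJob(jobContext: JobContext): Unit = {
    super.commitJob(jobContext)
    // Custom post-commit work, e.g. publishing metrics or writing a marker file.
  }
}

// Hypothetical output format whose only purpose is to supply the committer above.
// Its RecordWriter is never used; rows are still written through OutputWriter.
class AuditingOutputFormat extends TextOutputFormat[Void, String] {
  override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter =
    new AuditingOutputCommitter(FileOutputFormat.getOutputPath(context), context)
}

// Inside a concrete FSBasedRelation (remaining members elided):
//
//   override def prepareForWrite(job: Job): Unit = {
//     // Only mutate `job`; the relation instance itself may be cached and reused.
//     job.setOutputFormatClass(classOf[AuditingOutputFormat])
//   }

Per the diff above, prepareForWrite runs once in driverSideSetup; each task then re-creates the committer from the resolved output format class in executorSideSetup.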