diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 24d8920f00278..842c51d17c4f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat}
 import org.apache.hadoop.util.Shell
 import parquet.hadoop.util.ContextUtil
 
@@ -211,7 +211,7 @@ private[sql] abstract class BaseWriterContainer(
   @transient private val jobContext: JobContext = job
 
   // The following fields are initialized and used on both driver and executor side.
-  @transient private var outputCommitter: OutputCommitter = _
+  @transient protected var outputCommitter: FileOutputCommitter = _
   @transient private var jobId: JobID = _
   @transient private var taskId: TaskID = _
   @transient private var taskAttemptId: TaskAttemptID = _
@@ -235,7 +235,11 @@ private[sql] abstract class BaseWriterContainer(
     setupConf()
     taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
     val outputFormat = relation.outputFormatClass.newInstance()
-    outputCommitter = outputFormat.getOutputCommitter(taskAttemptContext)
+    outputCommitter = outputFormat.getOutputCommitter(taskAttemptContext) match {
+      case c: FileOutputCommitter => c
+      case _ => sys.error(
+        s"Output committer must be ${classOf[FileOutputCommitter].getName} or its subclasses")
+    }
     outputCommitter.setupJob(jobContext)
   }
 
@@ -244,7 +248,11 @@ private[sql] abstract class BaseWriterContainer(
     setupConf()
     taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
     val outputFormat = outputFormatClass.newInstance()
-    outputCommitter = outputFormat.getOutputCommitter(taskAttemptContext)
+    outputCommitter = outputFormat.getOutputCommitter(taskAttemptContext) match {
+      case c: FileOutputCommitter => c
+      case _ => sys.error(
+        s"Output committer must be ${classOf[FileOutputCommitter].getName} or its subclasses")
+    }
     outputCommitter.setupTask(taskAttemptContext)
     initWriters()
   }
@@ -298,7 +306,7 @@ private[sql] class DefaultWriterContainer(
 
   override protected def initWriters(): Unit = {
     writer = relation.outputWriterClass.newInstance()
-    writer.init(outputPath, dataSchema, taskAttemptContext)
+    writer.init(outputCommitter.getWorkPath.toString, dataSchema, taskAttemptContext)
   }
 
   override def outputWriterForRow(row: Row): OutputWriter = writer
@@ -340,7 +348,7 @@ private[sql] class DynamicPartitionWriterContainer(
     }.mkString
 
     outputWriters.getOrElseUpdate(partitionPath, {
-      val path = new Path(outputPath, partitionPath.stripPrefix(Path.SEPARATOR))
+      val path = new Path(outputCommitter.getWorkPath, partitionPath.stripPrefix(Path.SEPARATOR))
       val writer = outputWriterClass.newInstance()
       writer.init(path.toString, dataSchema, taskAttemptContext)
       writer
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 07e77f8a24100..6060396f59bb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.sources
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.mapreduce.{TaskAttemptContext, OutputFormat, OutputCommitter}
-import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.deploy.SparkHadoopUtil
@@ -411,7 +411,7 @@ abstract class FSBasedRelation private[sql](
     buildScan(requiredColumns, inputPaths)
   }
 
-  def outputFormatClass: Class[_ <: OutputFormat[Void, Row]]
+  def outputFormatClass: Class[_ <: FileOutputFormat[Void, Row]]
 
   /**
   * This method is responsible for producing a new [[OutputWriter]] for each newly opened output
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala
index d5948e66bfac1..4fc12fa04523a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala
@@ -20,10 +20,9 @@ package org.apache.spark.sql.sources
 import scala.collection.mutable
 
 import com.google.common.base.Objects
-import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.mapreduce.{TaskAttemptContext, OutputFormat}
-import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.rdd.RDD
@@ -110,7 +109,7 @@ class SimpleFSBasedRelation
 
   override def outputWriterClass: Class[_ <: OutputWriter] = classOf[SimpleOutputWriter]
 
-  override def outputFormatClass: Class[_ <: OutputFormat[Void, Row]] = {
+  override def outputFormatClass: Class[_ <: FileOutputFormat[Void, Row]] = {
     // This is just a mock, not used within this test suite.
     classOf[TextOutputFormat[Void, Row]]
   }
@@ -268,7 +267,7 @@ class FSBasedRelationSuite extends QueryTest with BeforeAndAfter {
     val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i")
 
     TestResult.synchronized {
-      assert(TestResult.writerPaths.size === 2)
+      assert(TestResult.writerPaths.size === 4)
       assert(TestResult.writtenRows === expectedRows.toSet)
     }
   }
@@ -295,7 +294,7 @@ class FSBasedRelationSuite extends QueryTest with BeforeAndAfter {
     val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i")
 
     TestResult.synchronized {
-      assert(TestResult.writerPaths.size === 2)
+      assert(TestResult.writerPaths.size === 4)
       assert(TestResult.writtenRows === expectedRows.toSet)
     }
   }
@@ -328,7 +327,7 @@ class FSBasedRelationSuite extends QueryTest with BeforeAndAfter {
     val expectedRows = for (i <- 1 to 3; _ <- 1 to 4) yield Row(i, s"val_$i")
 
     TestResult.synchronized {
-      assert(TestResult.writerPaths.size === 4)
+      assert(TestResult.writerPaths.size === 8)
       assert(TestResult.writtenRows === expectedRows.toSet)
     }
   }
@@ -381,7 +380,7 @@ class FSBasedRelationSuite extends QueryTest with BeforeAndAfter {
     val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i")
 
     TestResult.synchronized {
-      assert(TestResult.writerPaths.size === 2)
+      assert(TestResult.writerPaths.size === 4)
       assert(TestResult.writtenRows === expectedRows.toSet)
     }
 
@@ -443,7 +442,7 @@ class FSBasedRelationSuite extends QueryTest with BeforeAndAfter {
     val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i")
 
     TestResult.synchronized {
-      assert(TestResult.writerPaths.size === 2)
+      assert(TestResult.writerPaths.size === 4)
       assert(TestResult.writtenRows === expectedRows.toSet)
     }
 
@@ -472,7 +471,7 @@ class FSBasedRelationSuite extends QueryTest with BeforeAndAfter {
     val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i")
 
     TestResult.synchronized {
-      assert(TestResult.writerPaths.size === 2)
+      assert(TestResult.writerPaths.size === 4)
       assert(TestResult.writtenRows === expectedRows.toSet)
     }
 
@@ -493,7 +492,7 @@ class FSBasedRelationSuite extends QueryTest with BeforeAndAfter {
     val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i")
 
     TestResult.synchronized {
-      assert(TestResult.writerPaths.size === 2)
+      assert(TestResult.writerPaths.size === 8)
       assert(TestResult.writtenRows === expectedRows.toSet)
     }
 
@@ -535,7 +534,7 @@ class FSBasedRelationSuite extends QueryTest with BeforeAndAfter {
    val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i")
 
     TestResult.synchronized {
-      assert(TestResult.writerPaths.size === 2)
+      assert(TestResult.writerPaths.size === 4)
      assert(TestResult.writtenRows === expectedRows.toSet)
     }
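
A minimal standalone sketch of the pattern this diff applies, not part of the patch itself: the committer returned by the output format is narrowed to FileOutputCommitter, and writers target its per-attempt work path instead of the final output path, so commitTask/commitJob can promote files atomically. The object and method names below (CommitterWorkPathSketch, requireFileOutputCommitter, writerOutputDir) are illustrative assumptions, not Spark APIs.

// Hedged illustration of the committer narrowing + work-path technique shown above.
import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat}

object CommitterWorkPathSketch {
  // Narrow a generic OutputCommitter to FileOutputCommitter, mirroring the
  // pattern match added in BaseWriterContainer; fail fast on anything else.
  def requireFileOutputCommitter(committer: OutputCommitter): FileOutputCommitter =
    committer match {
      case c: FileOutputCommitter => c
      case _ => sys.error(
        s"Output committer must be ${classOf[FileOutputCommitter].getName} or its subclasses")
    }

  // Writers should emit files under the committer's work path (a task-attempt
  // scoped temporary directory), not the final output directory, so failed or
  // speculative attempts never clobber committed data.
  def writerOutputDir(format: FileOutputFormat[_, _], context: TaskAttemptContext): String = {
    val committer = requireFileOutputCommitter(format.getOutputCommitter(context))
    committer.getWorkPath.toString
  }
}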