
Commit

Uses projection to separate partition columns and data columns while inserting rows
liancheng committed May 12, 2015
1 parent 795920a commit bc3f9b4
Showing 2 changed files with 61 additions and 41 deletions.
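The change below drops the old approach of running two separate plans (one per column group) and zipping their RDDs, and instead splits each incoming row with two precomputed projections: one selecting the partition columns, one selecting the data columns. A minimal, self-contained sketch of that idea, using plain Scala stand-ins (a Row as Seq[Any], column-name strings instead of Catalyst attributes) rather than Spark's actual Projection machinery:

// Minimal stand-ins: a Row is just a Seq[Any]; a schema is a list of column names.
object ProjectionSplitSketch {
  type Row = Seq[Any]

  // Builds a projection that picks the named columns out of a row laid out as `schema`.
  def projection(schema: Seq[String], columns: Seq[String]): Row => Row = {
    val indices = columns.map(c => schema.indexOf(c))
    row => indices.map(i => row(i))
  }

  def main(args: Array[String]): Unit = {
    val schema           = Seq("year", "month", "name", "value")
    val partitionColumns = Seq("year", "month")
    val dataColumns      = schema.filterNot(partitionColumns.contains)

    val partitionProj = projection(schema, partitionColumns)
    val dataProj      = projection(schema, dataColumns)

    // A single pass over the input rows: each row is split into the values that
    // pick the output partition and the values that are actually written.
    val rows: Seq[Row] = Seq(Seq(2015, 5, "a", 1), Seq(2015, 6, "b", 2))
    rows.foreach { row =>
      println(s"partition=${partitionProj(row)} data=${dataProj(row)}")
    }
  }
}

In the diff itself the same per-row split is done by partitionProj and dataProj built from the plan's output attributes, so only one physical plan is executed.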
96 changes: 60 additions & 36 deletions sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -32,6 +32,7 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
@@ -102,14 +103,18 @@ private[sql] case class InsertIntoFSBasedRelation(
} else {
val writerContainer = new DynamicPartitionWriterContainer(
relation, job, partitionColumns, "__HIVE_DEFAULT_PARTITION__")
insertWithDynamicPartitions(writerContainer, df, partitionColumns)
insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns)
}
}

Seq.empty[Row]
}

private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = {
// Uses local vals for serialization
val needsConversion = relation.needConversion
val dataSchema = relation.dataSchema

try {
writerContainer.driverSideSetup()
df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
@@ -124,8 +129,8 @@
writerContainer.executorSideSetup(taskContext)

try {
if (relation.needConversion) {
val converter = CatalystTypeConverters.createToScalaConverter(relation.dataSchema)
if (needsConversion) {
val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
while (iterator.hasNext) {
val row = converter(iterator.next()).asInstanceOf[Row]
writerContainer.outputWriterForRow(row).write(row)
@@ -145,9 +150,13 @@
}

private def insertWithDynamicPartitions(
sqlContext: SQLContext,
writerContainer: BaseWriterContainer,
df: DataFrame,
partitionColumns: Array[String]): Unit = {
// Uses a local val for serialization
val needsConversion = relation.needConversion
val dataSchema = relation.dataSchema

require(
df.schema == relation.schema,
@@ -156,34 +165,21 @@
|Relation schema: ${relation.schema}
""".stripMargin)

val sqlContext = df.sqlContext

val (partitionRDD, dataRDD) = {
val fieldNames = relation.schema.fieldNames
val dataCols = fieldNames.filterNot(partitionColumns.contains)
val df = sqlContext.createDataFrame(
DataFrame(sqlContext, query).queryExecution.toRdd,
relation.schema,
needsConversion = false)

val partitionColumnsInSpec = relation.partitionSpec.partitionColumns.map(_.name)
require(
partitionColumnsInSpec.sameElements(partitionColumns),
s"""Partition columns mismatch.
|Expected: ${partitionColumnsInSpec.mkString(", ")}
|Actual: ${partitionColumns.mkString(", ")}
""".stripMargin)

val partitionDF = df.select(partitionColumns.head, partitionColumns.tail: _*)
val dataDF = df.select(dataCols.head, dataCols.tail: _*)
val partitionColumnsInSpec = relation.partitionColumns.fieldNames
require(
partitionColumnsInSpec.sameElements(partitionColumns),
s"""Partition columns mismatch.
|Expected: ${partitionColumnsInSpec.mkString(", ")}
|Actual: ${partitionColumns.mkString(", ")}
""".stripMargin)

partitionDF.queryExecution.executedPlan.execute() ->
dataDF.queryExecution.executedPlan.execute()
}
val output = df.queryExecution.executedPlan.output
val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name))
val codegenEnabled = df.sqlContext.conf.codegenEnabled

try {
writerContainer.driverSideSetup()
sqlContext.sparkContext.runJob(partitionRDD.zip(dataRDD), writeRows _)
df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
writerContainer.commitJob()
relation.refresh()
} catch { case cause: Throwable =>
@@ -192,20 +188,44 @@
throw new SparkException("Job aborted.", cause)
}

def writeRows(taskContext: TaskContext, iterator: Iterator[(Row, Row)]): Unit = {
def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = {
writerContainer.executorSideSetup(taskContext)

try {
val partitionProj = newProjection(codegenEnabled, partitionOutput, output)
val dataProj = newProjection(codegenEnabled, dataOutput, output)

if (needsConversion) {
val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
while (iterator.hasNext) {
val (partitionValues, data) = iterator.next()
writerContainer.outputWriterForRow(partitionValues).write(data)
val row = converter(iterator.next()).asInstanceOf[Row]
val partitionPart = partitionProj(row)
val dataPart = dataProj(row)
writerContainer.outputWriterForRow(partitionPart).write(dataPart)
}
} else {
while (iterator.hasNext) {
val row = iterator.next()
val partitionPart = partitionProj(row)
val dataPart = dataProj(row)
writerContainer.outputWriterForRow(partitionPart).write(dataPart)
}

writerContainer.commitTask()
} catch { case cause: Throwable =>
writerContainer.abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}

writerContainer.commitTask()
}
}

// This is copied from SparkPlan, probably should move this to a more general place.
private def newProjection(
codegenEnabled: Boolean,
expressions: Seq[Expression],
inputSchema: Seq[Attribute]): Projection = {
log.debug(
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if (codegenEnabled) {
GenerateProjection.generate(expressions, inputSchema)
} else {
new InterpretedProjection(expressions, inputSchema)
}
}
}
@@ -379,6 +399,10 @@ private[sql] class DynamicPartitionWriterContainer(
}

private[sql] object DynamicPartitionWriterContainer {
//////////////////////////////////////////////////////////////////////////////////////////////////
// The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils).
//////////////////////////////////////////////////////////////////////////////////////////////////

val charToEscape = {
val bitSet = new java.util.BitSet(128)

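The escaping code that follows charToEscape is truncated in this hunk. As a rough sketch of the Hive-style scheme the banner comment refers to, the idea is to replace every character flagged in a bit set like charToEscape with '%' followed by its two-digit hex code; the character set below is an illustrative subset chosen for the example, not the exact set encoded by charToEscape in Spark or Hive.

// Rough sketch of Hive-style escaping for partition directory names:
// flagged characters become '%' plus their two-digit hex code.
object PartitionPathEscapingSketch {
  // Illustrative subset only; the real set lives in the charToEscape bit set.
  private val needsEscaping: Set[Char] = Set('"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\')

  def escapePathName(path: String): String =
    path.flatMap { c =>
      if (c < ' ' || needsEscaping(c)) f"%%${c.toInt}%02X" else c.toString
    }

  def main(args: Array[String]): Unit = {
    // A partition value containing '/' or ':' must not create nested directories.
    println(escapePathName("2015/05/12 10:30"))   // prints 2015%2F05%2F12 10%3A30
  }
}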
@@ -61,23 +61,19 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullWritable, Text] {
}

class SimpleTextOutputWriter extends OutputWriter {
private var converter: Any => Any = _
private var recordWriter: RecordWriter[NullWritable, Text] = _
private var taskAttemptContext: TaskAttemptContext = _

override def init(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): Unit = {
converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
recordWriter = new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)
taskAttemptContext = context
}

override def write(row: Row): Unit = {
// Serializes values in `row` into a comma separated string
val convertedRow = converter(row).asInstanceOf[Row]
val serialized = convertedRow.toSeq.map(_.toString).mkString(",")
val serialized = row.toSeq.map(_.toString).mkString(",")
recordWriter.write(null, new Text(serialized))
}

