From f18dec2a858b6f662ca7ba3e87b1363e1e790996 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 30 Apr 2015 00:54:18 +0800 Subject: [PATCH] More strict schema checking --- .../apache/spark/sql/sources/commands.scala | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 3e97506ba2e25..ee3b9cab5e0fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -132,28 +132,32 @@ private[sql] case class InsertIntoFSBasedRelation( df: DataFrame, partitionColumns: Array[String]): Unit = { + require( + df.schema == relation.schema, + s"""DataFrame must have the same schema as the relation to which is inserted. + |DataFrame schema: ${df.schema} + |Relation schema: ${relation.schema} + """.stripMargin) + val sqlContext = df.sqlContext val (partitionRDD, dataRDD) = { val fieldNames = relation.schema.fieldNames - val (partitionCols, dataCols) = fieldNames.partition(partitionColumns.contains) + val dataCols = fieldNames.filterNot(partitionColumns.contains) val df = sqlContext.createDataFrame( DataFrame(sqlContext, query).queryExecution.toRdd, relation.schema, needsConversion = false) - assert( - partitionCols.sameElements(partitionColumns), { - val insertionPartitionCols = partitionColumns.mkString(",") - val relationPartitionCols = - relation.partitionSpec.partitionColumns.fieldNames.mkString(",") - s"""Partition columns mismatch. - |Expected: $relationPartitionCols - |Actual: $insertionPartitionCols - """.stripMargin - }) - - val partitionDF = df.select(partitionCols.head, partitionCols.tail: _*) + val partitionColumnsInSpec = relation.partitionSpec.partitionColumns.map(_.name) + require( + partitionColumnsInSpec.sameElements(partitionColumns), + s"""Partition columns mismatch. + |Expected: ${partitionColumnsInSpec.mkString(", ")} + |Actual: ${partitionColumns.mkString(", ")} + """.stripMargin) + + val partitionDF = df.select(partitionColumns.head, partitionColumns.tail: _*) val dataDF = df.select(dataCols.head, dataCols.tail: _*) (partitionDF.rdd, dataDF.rdd)