Commit 424e987

[SPARK-6672][SQL] convert row to catalyst in createDataFrame(RDD[Row], ...)

We assume that `RDD[Row]` contains Scala types, so we need to convert them into Catalyst types in `createDataFrame` (a reproduction sketch follows the changed-files summary below). cc liancheng

Author: Xiangrui Meng <[email protected]>

Closes apache#5329 from mengxr/SPARK-6672 and squashes the following commits:

2d52644 [Xiangrui Meng] set needsConversion = false in jsonRDD
06896e4 [Xiangrui Meng] add createDataFrame without conversion
4a3767b [Xiangrui Meng] convert Row to catalyst
mengxr authored and liancheng committed Apr 2, 2015
1 parent 6562787 commit 424e987
Showing 7 changed files with 37 additions and 8 deletions.
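
The problem, reproduced with the `ExamplePoint` UDT from Spark's sql test sources (this is the same scenario the new `DataFrameSuite` test below covers; `TestSQLContext` is the test-only context):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, TestSQLContext}
    import org.apache.spark.sql.types.{StructField, StructType}

    // A Row carrying a plain Scala object whose declared column type is a UDT.
    val rowRDD = TestSQLContext.sparkContext.parallelize(
      Seq(Row(new ExamplePoint(1.0, 2.0))))
    val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))

    // Before this commit the ExamplePoint was stored as-is instead of being run
    // through its UDT's Catalyst conversion, and evaluation failed downstream;
    // with the fix, createDataFrame converts the rows first.
    TestSQLContext.createDataFrame(rowRDD, schema).rdd.collect()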
5 changes: 5 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

@@ -72,6 +72,11 @@ trait ScalaReflection {
     case (d: BigDecimal, _) => Decimal(d)
     case (d: java.math.BigDecimal, _) => Decimal(d)
     case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d)
+    case (r: Row, structType: StructType) =>
+      new GenericRow(
+        r.toSeq.zip(structType.fields).map { case (elem, field) =>
+          convertToCatalyst(elem, field.dataType)
+        }.toArray)
     case (other, _) => other
   }
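
The added case recurses, so values inside nested rows are also converted against their field types. A minimal standalone sketch of that recursion, using hypothetical stand-in types rather than Spark's `Row` and `StructType`:

    // Simplified stand-ins for Spark's Row and StructType, for illustration only.
    case class SimpleField(name: String, dataType: Any)
    case class SimpleStruct(fields: Seq[SimpleField])
    case class SimpleRow(values: Seq[Any])

    // Mirrors convertToCatalyst: convert leaf values, and recurse into nested
    // rows field by field so inner values meet their own field types.
    def convert(elem: Any, dataType: Any): Any = (elem, dataType) match {
      case (d: java.sql.Date, _) => d.getTime  // a leaf conversion
      case (r: SimpleRow, s: SimpleStruct) =>
        SimpleRow(r.values.zip(s.fields).map { case (e, f) => convert(e, f.dataType) })
      case (other, _) => other
    }

    // A nested row: the inner Date is converted even one level down.
    val struct = SimpleStruct(Seq(
      SimpleField("name", "string"),
      SimpleField("inner", SimpleStruct(Seq(SimpleField("d", "date"))))))
    val nested = SimpleRow(Seq("a", SimpleRow(Seq(java.sql.Date.valueOf("2015-04-02")))))
    convert(nested, struct)  // SimpleRow(Seq("a", SimpleRow(Seq(<epoch millis>))))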
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -904,7 +904,8 @@ class DataFrame private[sql](
    */
   override def repartition(numPartitions: Int): DataFrame = {
     sqlContext.createDataFrame(
-      queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema)
+      queryExecution.toRdd.map(_.copy()).repartition(numPartitions),
+      schema, needsConversion = false)
   }
 
   /**
20 changes: 17 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -392,9 +392,23 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   @DeveloperApi
   def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+    createDataFrame(rowRDD, schema, needsConversion = true)
+  }
+
+  /**
+   * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be
+   * converted to Catalyst rows.
+   */
+  private[sql]
+  def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = {
     // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
     // schema differs from the existing schema on any field data type.
-    val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
+    val catalystRows = if (needsConversion) {
+      rowRDD.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])
+    } else {
+      rowRDD
+    }
+    val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
     DataFrame(this, logicalPlan)
   }
 
@@ -604,7 +618,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       JsonRDD.nullTypeToStringType(
         JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
     val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    createDataFrame(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
   }
 
   /**
@@ -633,7 +647,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       JsonRDD.nullTypeToStringType(
         JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
     val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    createDataFrame(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
   }
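
Call sites that feed `queryExecution.toRdd` into `createDataFrame` (`DataFrame.repartition` above, and the Parquet and data-source insert paths below) pass `needsConversion = false` because those rows are already in Catalyst form, as is the output of `JsonRDD.jsonStringToRow` in the two `jsonRDD` changes; only external input through the public overload is converted. A usage sketch against the 1.3-era public API, assuming an existing `sqlContext: SQLContext`:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{DateType, StringType, StructField, StructType}

    // External rows built from plain Scala/Java values, e.g. java.sql.Date.
    val rows = sqlContext.sparkContext.parallelize(Seq(
      Row("alice", java.sql.Date.valueOf("2015-04-02"))))
    val schema = StructType(Seq(
      StructField("name", StringType, nullable = false),
      StructField("joined", DateType, nullable = false)))

    // Public overload: delegates with needsConversion = true, so the Date is
    // mapped through convertToCatalyst (DateUtils.fromJavaDate) before planning.
    val df = sqlContext.createDataFrame(rows, schema)
    df.collect()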
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala

@@ -122,7 +122,8 @@ private[sql] class DefaultSource
     val df =
       sqlContext.createDataFrame(
         data.queryExecution.toRdd,
-        data.schema.asNullable)
+        data.schema.asNullable,
+        needsConversion = false)
     val createdRelation =
       createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2]
     createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite)
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala

@@ -31,7 +31,8 @@ private[sql] case class InsertIntoDataSource(
     val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
     val data = DataFrame(sqlContext, query)
     // Apply the schema of the existing table to the new data.
-    val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
+    val df = sqlContext.createDataFrame(
+      data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false)
     relation.insert(df, overwrite)
 
     // Invalidate the cache.
2 changes: 1 addition & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/test/ExamplePointUDT.scala

@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
  * @param y y coordinate
  */
 @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
-private[sql] class ExamplePoint(val x: Double, val y: Double)
+private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable
 
 /**
  * User-defined type for [[ExamplePoint]].
9 changes: 8 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

@@ -21,7 +21,7 @@ import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
 import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 import org.apache.spark.sql.test.TestSQLContext.sql
@@ -506,4 +506,11 @@ class DataFrameSuite extends QueryTest {
     testData.select($"*").show()
     testData.select($"*").show(1000)
   }
+
+  test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
+    val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
+    val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
+    val df = TestSQLContext.createDataFrame(rowRDD, schema)
+    df.rdd.collect()
+  }
 }
