[SPARK-15214][SQL] Code-generation for Generate #13065

Closed
hvanhovell wants to merge 39 commits
Changes from 33 commits
5254559
CG for Generate/Explode
hvanhovell May 11, 2016
9df56a9
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell May 11, 2016
5d068b5
Fix compilation and binding errors.
hvanhovell May 12, 2016
b2d663b
Fix infinite loop on continue.
hvanhovell May 12, 2016
e04d66f
Fix outer.
hvanhovell May 12, 2016
43a04bf
Really really fix lateral outer view explode(...).
hvanhovell May 12, 2016
7b4772d
Add benchmark & fix subexpressionsuite.
hvanhovell May 17, 2016
b721b60
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell May 22, 2016
09513e7
Generate WIP
hvanhovell May 25, 2016
f7c2307
Update GenerateExec
hvanhovell May 30, 2016
f5bd9cf
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell May 30, 2016
dba4240
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell May 31, 2016
49f9e7f
Further tweaking...
hvanhovell Jun 3, 2016
b3531cb
Proper support for json_tuple.
hvanhovell Jun 3, 2016
5cfba19
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell Jun 3, 2016
f86da0f
Use TraversableOnce for regular Generators.
hvanhovell Jun 7, 2016
60da24e
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell Jun 7, 2016
87688b1
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell Jun 9, 2016
1d2d595
Add benchmarks for explode map & json_tuple
hvanhovell Jun 9, 2016
c9b3eda
fix generated json
hvanhovell Jun 9, 2016
2732b06
disable benchmark
hvanhovell Jun 9, 2016
36cd826
Add tests for generate with outer = true
hvanhovell Jun 26, 2016
5b3d9bd
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell Aug 25, 2016
3a40952
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell Aug 29, 2016
c41e308
Add new generators & update.
hvanhovell Aug 29, 2016
116339a
Fix Stack
hvanhovell Aug 30, 2016
2c6c7f2
Make Stack use the iteration path.
hvanhovell Aug 30, 2016
d20114b
Update benchmarks
hvanhovell Aug 30, 2016
ad36de5
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell Oct 12, 2016
757b470
Revert json_tuple changes
hvanhovell Oct 12, 2016
8c14194
Touch-ups
hvanhovell Oct 12, 2016
459714c
Touch-ups
hvanhovell Oct 12, 2016
7b7fa6e
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell Nov 3, 2016
ebd9d8c
Merge remote-tracking branch 'apache-github/master' into SPARK-15214
hvanhovell Nov 15, 2016
f81eed7
Code review
hvanhovell Nov 15, 2016
29c606a
code review
hvanhovell Nov 17, 2016
af9a516
code review 2
hvanhovell Nov 18, 2016
3146cc5
Add proper fallback for 'Stack' generator. Make traversable once oute…
hvanhovell Nov 18, 2016
ffd5ef8
Minor thing
hvanhovell Nov 19, 2016
@@ -17,10 +17,12 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._

@@ -62,6 +64,21 @@ trait Generator extends Expression {
def terminate(): TraversableOnce[InternalRow] = Nil
}

/**
* A collection producing [[Generator]]. This trait provides a different path for code generation,
* by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object.
*/
trait CollectionGenerator extends Generator {
/** The position of an element within the collection should also be returned. */
def position: Boolean

/** Rows will be inlined during generation. */
def inline: Boolean

/** The schema of the returned collection object. */
def collectionSchema: DataType = dataType
}
Member

Would collectionType be a better name?

Member

@maropu Aug 30, 2016

By the way, do we need this interface? I think adding a new interface makes the code more complicated.

Contributor Author

@hvanhovell Aug 30, 2016

I needed a way to make sure that the collection-based code path (iteration over ArrayData/MapData) can be easily identified. The other option would be to hard-code all Generators that support this code path, but that seemed just wrong to me.

Contributor

+1 for collectionType
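
The trade-off discussed in this thread can be illustrated with a minimal, self-contained sketch. It assumes simplified stand-in types (`Generator`, `CollectionGenerator`, and three hypothetical case objects), not Spark's real classes, and only shows how a marker trait lets the caller choose a code path with a single type test instead of a hard-coded list of generators.

```scala
object CollectionGeneratorSketch {
  // Hypothetical stand-ins for illustration; these are not Spark's real classes.
  trait Generator

  // Marker trait: generators whose output can be iterated as a collection.
  trait CollectionGenerator extends Generator {
    def position: Boolean
    def inline: Boolean
  }

  case object Explode extends CollectionGenerator {
    val position = false
    val inline = false
  }

  case object PosExplode extends CollectionGenerator {
    val position = true
    val inline = false
  }

  // A generator that does not opt in to the collection path.
  case object JsonTuple extends Generator

  // One type test selects the path; no per-generator list has to be maintained.
  def codePath(g: Generator): String = g match {
    case c: CollectionGenerator if c.position => "collection with position"
    case _: CollectionGenerator => "collection"
    case _ => "iterator"
  }
}
```

Any new generator that mixes in the marker trait would be picked up by `codePath` without further changes, which appears to be the motivation given in the reply above.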


/**
* A generator that produces its output using the provided lambda function.
*/
@@ -77,7 +94,9 @@ case class UserDefinedGenerator(
private def initializeConverters(): Unit = {
inputRow = new InterpretedProjection(children)
convertToScala = {
val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
val inputSchema = StructType(children.map { e =>
StructField(e.simpleString, e.dataType, nullable = true)
})
CatalystTypeConverters.createToScalaConverter(inputSchema)
}.asInstanceOf[InternalRow => Row]
}
@@ -109,8 +128,7 @@ case class UserDefinedGenerator(
1 2
3 NULL
""")
case class Stack(children: Seq[Expression])
extends Expression with Generator with CodegenFallback {
case class Stack(children: Seq[Expression]) extends Generator {

private lazy val numRows = children.head.eval().asInstanceOf[Int]
private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
@@ -149,29 +167,52 @@ case class Stack(children: Seq[Expression])
InternalRow(fields: _*)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Rows - we write these into an array.
val rowData = ctx.freshName("rows")
ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];")
val values = children.tail
val dataTypes = values.take(numFields).map(_.dataType)
val rows = for (row <- 0 until numRows) yield {
Contributor

Should we split these into multiple functions in case numRows is large?

Contributor Author

Done
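
For readers unfamiliar with the concern raised here: a very long generated method body can be a problem for the JVM, so code generators commonly group per-row snippets into several smaller helper methods. The following is a rough sketch of that idea only. It does not call any Spark API, and the `split` function, the `maxChars` threshold, and the `apply_N` method names are all hypothetical.

```scala
object SplitCodeSketch {
  import scala.collection.mutable.ArrayBuffer

  // Groups code snippets into helper methods so that no single method body
  // grows far beyond `maxChars` characters (a hypothetical threshold).
  // Returns the generated method definitions and the code that calls them.
  def split(snippets: Seq[String], maxChars: Int): (Seq[String], String) = {
    val chunks = ArrayBuffer(new StringBuilder)
    for (s <- snippets) {
      // Start a new chunk when adding this snippet would exceed the limit.
      if (chunks.last.nonEmpty && chunks.last.length + s.length > maxChars) {
        chunks += new StringBuilder
      }
      chunks.last.append(s).append('\n')
    }
    val methods = chunks.zipWithIndex.map { case (body, i) =>
      s"private void apply_$i() {\n$body}"
    }
    val calls = chunks.indices.map(i => s"apply_$i();").mkString("\n")
    (methods.toSeq, calls)
  }
}
```

With three short snippets and a small limit, for example `SplitCodeSketch.split(Seq("a;", "b;", "c;"), 5)`, the first two snippets share one helper and the third gets its own.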

val fields = for (col <- 0 until numFields) yield {
val index = row * numFields + col
if (index < values.length) values(index) else Literal(null, dataTypes(col))
}
val eval = CreateStruct(fields).genCode(ctx)
s"${eval.code}\nthis.$rowData[$row] = ${eval.value};"
}

// Create the iterator.
val wrapperClass = classOf[mutable.WrappedArray[_]].getName
ctx.addMutableState(
s"$wrapperClass<InternalRow>",
ev.value,
s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);")
ev.copy(code = rows.mkString("\n"), isNull = "false")
}
}
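
The row layout used by `Stack` (the `numFields` computation and the `row * numFields + col` index, with missing trailing cells filled by NULL) can be tried out in isolation. This is a plain-Scala sketch under the assumption that `None` stands in for a NULL literal; `StackLayoutSketch.stack` is a hypothetical helper, not part of the PR.

```scala
object StackLayoutSketch {
  // Lays `values` out into `numRows` rows of equal width, padding the tail
  // with None (standing in for NULL) when the values do not fill the last row.
  def stack[A](numRows: Int, values: Seq[A]): Seq[Seq[Option[A]]] = {
    // Same formula as in the diff: ceil(number of values / number of rows).
    val numFields = math.ceil(values.length.toDouble / numRows).toInt
    for (row <- 0 until numRows) yield {
      for (col <- 0 until numFields) yield {
        val index = row * numFields + col
        if (index < values.length) Some(values(index)) else None
      }
    }
  }
}
```

Calling `StackLayoutSketch.stack(2, Seq(1, 2, 3))` yields two rows of two fields each, the second row ending in `None`, which mirrors the `1 2` / `3 NULL` example in the expression description above.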

/**
* A base class for Explode and PosExplode
* A base class for [[Explode]] and [[PosExplode]].
*/
abstract class ExplodeBase(child: Expression, position: Boolean)
extends UnaryExpression with Generator with CodegenFallback with Serializable {
abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable {
override val inline: Boolean = false

override def checkInputDataTypes(): TypeCheckResult = {
if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case _: ArrayType | _: MapType =>
TypeCheckResult.TypeCheckSuccess
} else {
case _ =>
TypeCheckResult.TypeCheckFailure(
s"input to function explode should be array or map type, not ${child.dataType}")
}
}

// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
override def elementSchema: StructType = child.dataType match {
case ArrayType(et, containsNull) =>
if (position) {
new StructType()
.add("pos", IntegerType, false)
.add("pos", IntegerType, nullable = false)
.add("col", et, containsNull)
} else {
new StructType()
@@ -180,12 +221,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
case MapType(kt, vt, valueContainsNull) =>
if (position) {
new StructType()
.add("pos", IntegerType, false)
.add("key", kt, false)
.add("pos", IntegerType, nullable = false)
.add("key", kt, nullable = false)
.add("value", vt, valueContainsNull)
} else {
new StructType()
.add("key", kt, false)
.add("key", kt, nullable = false)
.add("value", vt, valueContainsNull)
}
}
@@ -218,6 +259,12 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
}
}
}

override def collectionSchema: DataType = child.dataType

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.genCode(ctx)
}
}

/**
@@ -239,7 +286,9 @@ abstract class ExplodeBase(child: Expression, position: Boolean)
20
""")
// scalastyle:on line.size.limit
case class Explode(child: Expression) extends ExplodeBase(child, position = false)
case class Explode(child: Expression) extends ExplodeBase {
override val position: Boolean = false
}

/**
* Given an input array produces a sequence of rows for each position and value in the array.
@@ -260,7 +309,9 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
1 20
""")
// scalastyle:on line.size.limit
case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
case class PosExplode(child: Expression) extends ExplodeBase {
override val position = true
}
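
As a reminder of what the two generators produce, here is a small plain-Scala sketch of the row shapes described by `elementSchema`: `col` or `pos, col` for arrays, and `key, value` or `pos, key, value` for maps. The function names are hypothetical, tuples stand in for rows, and a sequence of pairs stands in for an ordered map.

```scala
object ExplodeSketch {
  // explode(array): one single-column row ("col") per element.
  def explodeArray[A](xs: Seq[A]): Seq[Tuple1[A]] = xs.map(Tuple1(_))

  // posexplode(array): rows of ("pos", "col"), with positions starting at 0.
  def posExplodeArray[A](xs: Seq[A]): Seq[(Int, A)] =
    xs.zipWithIndex.map { case (x, i) => (i, x) }

  // explode(map): rows of ("key", "value").
  def explodeMap[K, V](entries: Seq[(K, V)]): Seq[(K, V)] = entries

  // posexplode(map): rows of ("pos", "key", "value").
  def posExplodeMap[K, V](entries: Seq[(K, V)]): Seq[(Int, K, V)] =
    entries.zipWithIndex.map { case ((k, v), i) => (i, k, v) }
}
```

For the array `10, 20` used in the expression descriptions, `posExplodeArray` gives the rows `(0, 10)` and `(1, 20)`.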

/**
* Explodes an array of structs into a table.
@@ -273,20 +324,24 @@ case class PosExplode(child: Expression) extends ExplodeBase(child, position = t
1 a
2 b
""")
case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator {
override val inline: Boolean = true
override val position: Boolean = false

override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(et, _) if et.isInstanceOf[StructType] =>
case ArrayType(st: StructType, _) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName should be array of struct type, not ${child.dataType}")
}

override def elementSchema: StructType = child.dataType match {
case ArrayType(et : StructType, _) => et
case ArrayType(st: StructType, _) => st
}

override def collectionSchema: DataType = child.dataType

private lazy val numFields = elementSchema.fields.length

override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -298,4 +353,8 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
yield inputArray.getStruct(i, numFields)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.genCode(ctx)
}
}
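
The change in `checkInputDataTypes` from an `isInstanceOf` guard to a typed pattern (`ArrayType(st: StructType, _)`) binds the struct type directly in the match. A minimal sketch of the same pattern, using a tiny hypothetical type hierarchy rather than Spark's `org.apache.spark.sql.types`:

```scala
object InlineTypeCheckSketch {
  // Hypothetical miniature of a data type hierarchy, for illustration only.
  sealed trait DataType
  case object IntegerType extends DataType
  case class StructType(fieldNames: Seq[String]) extends DataType
  case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType

  // The typed pattern both tests the element type and binds it as a StructType,
  // so no separate isInstanceOf check or cast is needed.
  def check(dt: DataType): Either[String, StructType] = dt match {
    case ArrayType(st: StructType, _) => Right(st)
    case other => Left(s"input to function inline should be array of struct type, not $other")
  }
}
```

An array of structs passes and returns its element schema, while an array of integers is rejected with the error message.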
@@ -17,7 +17,8 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.{DataType, IntegerType}

class SubexpressionEliminationSuite extends SparkFunSuite {
test("Semantic equals and hash") {
@@ -162,13 +163,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
test("Children of CodegenFallback") {
val one = Literal(1)
val two = Add(one, one)
val explode = Explode(two)
val add = Add(two, explode)
val fallback = CodegenFallbackExpression(two)
val add = Add(two, fallback)

var equivalence = new EquivalentExpressions
val equivalence = new EquivalentExpressions
equivalence.addExprTree(add, true)
// the `two` inside `explode` should not be added
// the `two` inside `fallback` should not be added
assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3) // add, two, explode
}
}

case class CodegenFallbackExpression(child: Expression)
extends UnaryExpression with CodegenFallback {
override def dataType: DataType = child.dataType
}
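
The behaviour this test checks can be modelled without Spark. The sketch below is a simplified assumption of how equivalence tracking works: leaves are skipped, every other expression is grouped with structurally equal ones, and the walk does not descend into the children of a fallback expression. `Lit`, `Add`, `Fallback`, and `groups` are hypothetical names, and the real `EquivalentExpressions` differs in detail.

```scala
object EquivalenceSketch {
  import scala.collection.mutable

  sealed trait Expr { def children: Seq[Expr] }
  case class Lit(value: Int) extends Expr { def children: Seq[Expr] = Nil }
  case class Add(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
  }
  // Stand-in for an expression that falls back to interpreted evaluation.
  case class Fallback(child: Expr) extends Expr {
    def children: Seq[Expr] = Seq(child)
  }

  // Groups structurally equal, non-leaf expressions found while walking `root`.
  def groups(root: Expr): Seq[Seq[Expr]] = {
    val seen = mutable.LinkedHashMap.empty[Expr, Vector[Expr]]
    def visit(e: Expr): Unit = e match {
      case _: Lit => () // leaves are not tracked
      case _ =>
        seen(e) = seen.getOrElse(e, Vector.empty) :+ e
        // Children of a fallback expression are deliberately not visited.
        if (!e.isInstanceOf[Fallback]) e.children.foreach(visit)
    }
    visit(root)
    seen.values.toSeq
  }
}
```

With `two = Add(Lit(1), Lit(1))` and `Add(two, Fallback(two))`, the `two` inside the fallback is never visited, so there are three groups of size one and none larger. Replacing the fallback with an ordinary `Add(two, Lit(1))` makes `two` appear twice and produces one group of size two.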