[SPARK-21871][SQL] Check actual bytecode size when compiling generated code

This PR adds code to check the actual bytecode size when compiling generated code. In apache#18810, we added code to give up code compilation and fall back to interpreted execution in `SparkPlan` when the number of lines in a generated function exceeds `maxLinesPerFunction`. However, the `CodeGenerator` object already collects metrics on compiled bytecode size, so we can easily reuse that code for this purpose.
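
Concretely, the new contract looks like the sketch below (a minimal illustration distilled from the diff, not a drop-in implementation; it assumes `cleanedSource: CodeAndComment`, the `child` plan, `references`, and `sqlContext` from `WholeStageCodegenExec.doExecute()` are in scope):

    // compile() now returns the generated class together with the max
    // bytecode size across the generated methods.
    val (clazz, maxCodeSize) = CodeGenerator.compile(cleanedSource)
    if (maxCodeSize > sqlContext.conf.hugeMethodLimit) {
      // HotSpot never JIT-compiles a method larger than HugeMethodLimit
      // (8000 bytes by default), so fall back to the interpreted child plan.
      return child.execute()
    }
    // Otherwise whole-stage codegen proceeds with the compiled class.
    val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]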

Added tests in `WholeStageCodegenSuite`.

Author: Takeshi Yamamuro <[email protected]>

Closes apache#19083 from maropu/SPARK-21871.
maropu authored and rdblue committed Mar 27, 2019
1 parent 430e0b1 commit 6972389
Showing 12 changed files with 108 additions and 31 deletions.
@@ -898,17 +898,23 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging
}

object CodeGenerator extends Logging {

+ // This is the value of HugeMethodLimit in the OpenJDK JVM settings
+ val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000

/**
* Compile the Java source code into a Java class, using Janino.
+ *
+ * @return a pair of a generated class and the max bytecode size of generated functions.
*/
- def compile(code: CodeAndComment): GeneratedClass = {
+ def compile(code: CodeAndComment): (GeneratedClass, Int) = {
cache.get(code)
}

/**
* Compile the Java source code into a Java class, using Janino.
*/
- private[this] def doCompile(code: CodeAndComment): GeneratedClass = {
+ private[this] def doCompile(code: CodeAndComment): (GeneratedClass, Int) = {
val evaluator = new ClassBodyEvaluator()

// A special classloader used to wrap the actual parent classloader of
@@ -946,22 +952,24 @@ object CodeGenerator extends Logging {
s"\n$formatted"
})

- try {
+ val maxCodeSize = try {
evaluator.cook("generated.java", code.body)
- recordCompilationStats(evaluator)
+ updateAndGetCompilationStats(evaluator)
} catch {
case e: Exception =>
val msg = s"failed to compile: $e\n$formatted"
logError(msg, e)
throw new Exception(msg, e)
}
- evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]

+ (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize)
}

/**
- * Records the generated class and method bytecode sizes by inspecting janino private fields.
+ * Returns the max bytecode size of the generated functions by inspecting janino private fields.
+ * Also, this method updates the metrics information.
*/
- private def recordCompilationStats(evaluator: ClassBodyEvaluator): Unit = {
+ private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): Int = {
// First retrieve the generated classes.
val classes = {
val resultField = classOf[SimpleCompiler].getDeclaredField("result")
@@ -976,23 +984,26 @@
val codeAttr = Utils.classForName("org.codehaus.janino.util.ClassFile$CodeAttribute")
val codeAttrField = codeAttr.getDeclaredField("code")
codeAttrField.setAccessible(true)
- classes.foreach { case (_, classBytes) =>
+ val codeSizes = classes.flatMap { case (_, classBytes) =>
CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length)
try {
val cf = new ClassFile(new ByteArrayInputStream(classBytes))
- cf.methodInfos.asScala.foreach { method =>
- method.getAttributes().foreach { a =>
- if (a.getClass.getName == codeAttr.getName) {
- CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(
- codeAttrField.get(a).asInstanceOf[Array[Byte]].length)
- }
+ val stats = cf.methodInfos.asScala.flatMap { method =>
+ method.getAttributes().filter(_.getClass.getName == codeAttr.getName).map { a =>
+ val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length
+ CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize)
+ byteCodeSize
}
}
+ Some(stats)
} catch {
case NonFatal(e) =>
logWarning("Error calculating stats of compiled class.", e)
+ None
}
- }
+ }.flatten

+ codeSizes.max
}

/**
@@ -1007,8 +1018,8 @@
private val cache = CacheBuilder.newBuilder()
.maximumSize(100)
.build(
- new CacheLoader[CodeAndComment, GeneratedClass]() {
- override def load(code: CodeAndComment): GeneratedClass = {
+ new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() {
+ override def load(code: CodeAndComment): (GeneratedClass, Int) = {
val startTime = System.nanoTime()
val result = doCompile(code)
val endTime = System.nanoTime()
@@ -142,7 +142,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableProjection]
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

- val c = CodeGenerator.compile(code)
- c.generate(ctx.references.toArray).asInstanceOf[MutableProjection]
+ val (clazz, _) = CodeGenerator.compile(code)
+ clazz.generate(ctx.references.toArray).asInstanceOf[MutableProjection]
}
}
@@ -185,7 +185,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]]
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
logDebug(s"Generated Ordering by ${ordering.mkString(",")}:\n${CodeFormatter.format(code)}")

- CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
+ val (clazz, _) = CodeGenerator.compile(code)
+ clazz.generate(ctx.references.toArray).asInstanceOf[BaseOrdering]
}
}

@@ -78,6 +78,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")

- CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
+ val (clazz, _) = CodeGenerator.compile(code)
+ clazz.generate(ctx.references.toArray).asInstanceOf[Predicate]
}
}
@@ -191,8 +191,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

- val c = CodeGenerator.compile(code)
+ val (clazz, _) = CodeGenerator.compile(code)
val resultRow = new SpecificInternalRow(expressions.map(_.dataType))
- c.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection]
+ clazz.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection]
}
}
@@ -409,7 +409,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection]
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

- val c = CodeGenerator.compile(code)
- c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
+ val (clazz, _) = CodeGenerator.compile(code)
+ clazz.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
}
}
@@ -196,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner]
val code = CodeFormatter.stripOverlappingComments(new CodeAndComment(codeBody, Map.empty))
logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")

- val c = CodeGenerator.compile(code)
- c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
+ val (clazz, _) = CodeGenerator.compile(code)
+ clazz.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
}
}
@@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.catalyst.analysis.Resolver
+ import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator

////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines the configuration options for Spark SQL.
@@ -585,6 +586,16 @@ object SQLConf {
.intConf
.createWithDefault(20)

+ val WHOLESTAGE_HUGE_METHOD_LIMIT = SQLConfigBuilder("spark.sql.codegen.hugeMethodLimit")
+ .internal()
+ .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " +
+ "codegen. When the compiled function exceeds this threshold, " +
+ "whole-stage codegen is deactivated for this subtree of the current query plan. " +
+ s"The default value is ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT}, " +
+ "which is the limit used by the OpenJDK JVM implementation.")
+ .intConf
+ .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT)

val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files.")
.longConf
@@ -895,6 +906,8 @@ class SQLConf extends Serializable with Logging {

def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)

+ def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)

def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)

def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)
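
As a usage note (not part of the diff): since `spark.sql.codegen.hugeMethodLimit` is a per-session SQL conf, the threshold can be raised at runtime when the fallback is too aggressive. A hedged sketch, assuming a `SparkSession` named `spark`:

    // Raise the bytecode-size threshold at which whole-stage codegen is
    // abandoned (the default mirrors HotSpot's HugeMethodLimit of 8000).
    spark.conf.set("spark.sql.codegen.hugeMethodLimit", 16000)
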
@@ -353,14 +353,25 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport
override def doExecute(): RDD[InternalRow] = {
val (ctx, cleanedSource) = doCodeGen()
// Try to compile, and fall back if it fails
- try {
+ val (_, maxCodeSize) = try {
CodeGenerator.compile(cleanedSource)
} catch {
case e: Exception if !Utils.isTesting && sqlContext.conf.wholeStageFallback =>
// We should have already seen the error message
logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString")
return child.execute()
}

+ // Check whether the compiled code has a function that is too large
+ if (maxCodeSize > sqlContext.conf.hugeMethodLimit) {
+ logWarning(s"Generated code is too long and JIT optimization might not work: " +
+ s"the bytecode size was $maxCodeSize, which exceeds the limit " +
+ s"${sqlContext.conf.hugeMethodLimit}, so whole-stage codegen was disabled " +
+ s"for this plan. To avoid this, you can raise the limit via " +
+ s"${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}:\n$treeString")
+ return child.execute()
+ }

val references = ctx.references.toArray

val durationMs = longMetric("pipelineTime")
@@ -371,7 +382,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport
if (rdds.length == 1) {
rdds.head.mapPartitionsWithIndex { (index, iter) =>
WholeStageCodegenExec.this.logInfo(logMsg)
- val clazz = CodeGenerator.compile(cleanedSource)
+ val (clazz, _) = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.init(index, Array(iter))
new Iterator[InternalRow] {
@@ -391,7 +402,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport
}.mapPartitionsWithIndex { (index, zippedIter) =>
WholeStageCodegenExec.this.logInfo(logMsg)
val (leftIter, rightIter) = zippedIter.next()
- val clazz = CodeGenerator.compile(cleanedSource)
+ val (clazz, _) = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.init(index, Array(leftIter, rightIter))
new Iterator[InternalRow] {
@@ -230,6 +230,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator]
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
logDebug(s"Generated ColumnarIterator:\n${CodeFormatter.format(code)}")

- CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator]
+ val (clazz, _) = CodeGenerator.compile(code)
+ clazz.generate(Array.empty).asInstanceOf[ColumnarIterator]
}
}
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{Column, Dataset, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack}
+ import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
@@ -178,4 +179,41 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
assert(df.collect() === Array(Row(1), Row(2)))
}
}

+ def genGroupByCode(caseNum: Int): CodeAndComment = {
+ val caseExp = (1 to caseNum).map { i =>
+ s"case when id > $i and id <= ${i + 1} then 1 else 0 end as v$i"
+ }.toList
+ val keyExp = List(
+ "id",
+ "(id & 1023) as k1",
+ "cast(id & 1023 as double) as k2",
+ "cast(id & 1023 as int) as k3")

+ val ds = spark.range(10)
+ .selectExpr(keyExp ::: caseExp: _*)
+ .groupBy("k1", "k2", "k3")
+ .sum()
+ val plan = ds.queryExecution.executedPlan

+ val wholeStageCodeGenExec = plan.find(p => p match {
+ case wp: WholeStageCodegenExec => wp.child match {
+ case hp: HashAggregateExec if (hp.child.isInstanceOf[ProjectExec]) => true
+ case _ => false
+ }
+ case _ => false
+ })

+ assert(wholeStageCodeGenExec.isDefined)
+ wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
+ }

+ test("SPARK-21871 check if we can get large code size when compiling too long functions") {
+ val codeWithShortFunctions = genGroupByCode(3)
+ val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions)
+ assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
+ val codeWithLongFunctions = genGroupByCode(20)
+ val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions)
+ assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
+ }
}
@@ -24,6 +24,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap
import org.apache.spark.sql.execution.vectorized.AggregateHashMap
+ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32