make DatasetBenchmark fairer for Dataset
cloud-fan committed Dec 23, 2016
1 parent 150d26c commit 43da0af
Showing 2 changed files with 56 additions and 64 deletions.
41 changes: 12 additions & 29 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -220,7 +220,7 @@ object DecimalLiteral {
 /**
  * In order to do type checking, use Literal.create() instead of constructor
  */
-case class Literal (value: Any, dataType: DataType) extends LeafExpression with CodegenFallback {
+case class Literal (value: Any, dataType: DataType) extends LeafExpression {
 
   override def foldable: Boolean = true
   override def nullable: Boolean = value == null
@@ -271,45 +271,28 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression with
       ev.isNull = "true"
       ev.copy(s"final ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};")
     } else {
-      dataType match {
-        case BooleanType =>
-          ev.isNull = "false"
-          ev.value = value.toString
-          ev.copy("")
-        case FloatType =>
-          val v = value.asInstanceOf[Float]
-          if (v.isNaN || v.isInfinite) {
-            super[CodegenFallback].doGenCode(ctx, ev)
-          } else {
-            ev.isNull = "false"
-            ev.value = s"${value}f"
-            ev.copy("")
-          }
-        case DoubleType =>
-          val v = value.asInstanceOf[Double]
-          if (v.isNaN || v.isInfinite) {
-            super[CodegenFallback].doGenCode(ctx, ev)
-          } else {
-            ev.isNull = "false"
-            ev.value = s"${value}D"
-            ev.copy("")
-          }
-        case ByteType | ShortType =>
-          ev.isNull = "false"
-          ev.value = s"(${ctx.javaType(dataType)})$value"
-          ev.copy("")
-        case IntegerType | DateType =>
-          ev.isNull = "false"
-          ev.value = value.toString
-          ev.copy("")
-        case TimestampType | LongType =>
-          ev.isNull = "false"
-          ev.value = s"${value}L"
-          ev.copy("")
-        // eval() version may be faster for non-primitive types
-        case other =>
-          super[CodegenFallback].doGenCode(ctx, ev)
-      }
+      ev.isNull = "false"
+      ev.value = dataType match {
+        case BooleanType | IntegerType | DateType => value.toString
+        case FloatType =>
+          val v = value.asInstanceOf[Float]
+          if (v.isNaN || v.isInfinite) {
+            ctx.addReferenceObj(v)
+          } else {
+            s"${value}f"
+          }
+        case DoubleType =>
+          val v = value.asInstanceOf[Double]
+          if (v.isNaN || v.isInfinite) {
+            ctx.addReferenceObj(v)
+          } else {
+            s"${value}D"
+          }
+        case ByteType | ShortType => s"(${ctx.javaType(dataType)})$value"
+        case TimestampType | LongType => s"${value}L"
+        case other => ctx.addReferenceObj("literal", value, ctx.javaType(dataType))
+      }
+      ev.copy("")
     }
   }
 

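Note on the change above: a non-primitive literal used to fall back to interpreted evaluation via super[CodegenFallback].doGenCode on every row; after this commit, ctx.addReferenceObj registers the literal object once and the generated Java simply reads it back. A minimal self-contained sketch of that mechanism (a stand-in for illustration, not Spark's actual CodegenContext; MiniCodegenContext and its members are invented names):

    import scala.collection.mutable.ArrayBuffer

    // Hypothetical stand-in for Spark's CodegenContext: registered objects are
    // handed to the compiled class as an Object[] and read back by index.
    class MiniCodegenContext {
      val references = new ArrayBuffer[Any]

      // Register `obj` once; return the Java expression that reads it at runtime.
      // (`name` is kept only to mirror the three-argument call in the diff.)
      def addReferenceObj(name: String, obj: Any, javaType: String): String = {
        references += obj
        s"(($javaType) references[${references.size - 1}])"
      }
    }

    object LiteralCodegenSketch extends App {
      val ctx = new MiniCodegenContext
      // A primitive literal is inlined directly into the generated source...
      val intCode = 42.toString                                         // "42"
      // ...while a non-primitive one becomes a reference-array lookup.
      val strCode = ctx.addReferenceObj("literal", "hello", "UTF8String")
      println(intCode)  // 42
      println(strCode)  // ((UTF8String) references[0])
    }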
79 changes: 44 additions & 35 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
@@ -34,11 +33,13 @@ object DatasetBenchmark {
   def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
     import spark.implicits._
 
-    val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+    val rdd = spark.sparkContext.range(0, numRows)
+    val ds = spark.range(0, numRows)
+    val df = ds.toDF("l")
+    val func = (l: Long) => l + 1
+
     val benchmark = new Benchmark("back-to-back map", numRows)
-    val func = (d: Data) => Data(d.l + 1, d.s)
 
-    val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
     benchmark.addCase("RDD") { iter =>
       var res = rdd
       var i = 0
@@ -53,14 +54,14 @@ object DatasetBenchmark {
       var res = df
       var i = 0
       while (i < numChains) {
-        res = res.select($"l" + 1 as "l", $"s")
+        res = res.select($"l" + 1 as "l")
         i += 1
       }
       res.queryExecution.toRdd.foreach(_ => Unit)
     }
 
     benchmark.addCase("Dataset") { iter =>
-      var res = df.as[Data]
+      var res = ds.as[Long]
       var i = 0
       while (i < numChains) {
         res = res.map(func)
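Note on the fairness fix: previously the RDD and Dataset cases mapped over a Data(l: Long, s: String) case class while the DataFrame case worked on plain columns, so the typed cases paid to encode and decode a string field the DataFrame never touched. After this change all three cases process bare longs. A runnable sketch of the three shapes being compared (local-mode session; FairMapSketch and the row count are illustrative):

    import org.apache.spark.sql.SparkSession

    object FairMapSketch extends App {
      val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
      import spark.implicits._

      val func = (l: Long) => l + 1

      val rdd = spark.sparkContext.range(0, 1000000L)  // raw JVM longs, no encoders
      val ds  = spark.range(0, 1000000L)               // Dataset[java.lang.Long]
      val df  = ds.toDF("l")

      rdd.map(func).count()                  // RDD: plain closure call per element
      df.select($"l" + 1 as "l").count()     // DataFrame: stays inside generated code, no closure
      ds.as[Long].map(func).count()          // Dataset: decode row -> closure -> re-encode per row

      spark.stop()
    }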
@@ -75,14 +76,14 @@ object DatasetBenchmark {
   def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
     import spark.implicits._
 
-    val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+    val rdd = spark.sparkContext.range(0, numRows)
+    val ds = spark.range(0, numRows)
+    val df = ds.toDF("l")
+    val func = (l: Long, i: Int) => l % (100L + i) == 0L
+    val funcs = 0.until(numChains).map { i => (l: Long) => func(l, i) }
+
     val benchmark = new Benchmark("back-to-back filter", numRows)
-    val func = (d: Data, i: Int) => d.l % (100L + i) == 0L
-    val funcs = 0.until(numChains).map { i =>
-      (d: Data) => func(d, i)
-    }
 
-    val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
     benchmark.addCase("RDD") { iter =>
       var res = rdd
       var i = 0
@@ -104,7 +105,7 @@ object DatasetBenchmark {
     }
 
     benchmark.addCase("Dataset") { iter =>
-      var res = df.as[Data]
+      var res = ds.as[Long]
       var i = 0
       while (i < numChains) {
         res = res.filter(funcs(i))
@@ -133,24 +134,29 @@
   def aggregate(spark: SparkSession, numRows: Long): Benchmark = {
     import spark.implicits._
 
-    val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+    val rdd = spark.sparkContext.range(0, numRows)
+    val ds = spark.range(0, numRows)
+    val df = ds.toDF("l")
+
     val benchmark = new Benchmark("aggregate", numRows)
 
-    val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
     benchmark.addCase("RDD sum") { iter =>
-      rdd.aggregate(0L)(_ + _.l, _ + _)
+      rdd.map(l => (l % 10, l)).reduceByKey(_ + _).foreach(_ => Unit)
     }
 
     benchmark.addCase("DataFrame sum") { iter =>
-      df.select(sum($"l")).queryExecution.toRdd.foreach(_ => Unit)
+      df.groupBy($"l" % 10).agg(sum($"l")).queryExecution.toRdd.foreach(_ => Unit)
     }
 
     benchmark.addCase("Dataset sum using Aggregator") { iter =>
-      df.as[Data].select(typed.sumLong((d: Data) => d.l)).queryExecution.toRdd.foreach(_ => Unit)
+      val result = ds.as[Long].groupByKey(_ % 10).agg(typed.sumLong[Long](identity))
+      result.queryExecution.toRdd.foreach(_ => Unit)
     }
 
+    val complexDs = df.select($"l", $"l".cast(StringType).as("s")).as[Data]
     benchmark.addCase("Dataset complex Aggregator") { iter =>
-      df.as[Data].select(ComplexAggregator.toColumn).queryExecution.toRdd.foreach(_ => Unit)
+      val result = complexDs.groupByKey(_.l % 10).agg(ComplexAggregator.toColumn)
+      result.queryExecution.toRdd.foreach(_ => Unit)
    }
 
     benchmark
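ComplexAggregator is defined elsewhere in DatasetBenchmark.scala and is not part of this diff. As a rough idea of the shape involved, a typed Aggregator over Data could look like the sketch below (an assumption about its structure, not the file's actual definition; SumLongAggregator is an invented name):

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.expressions.Aggregator

    case class Data(l: Long, s: String)

    // Hypothetical Aggregator: sums the numeric column of Data via the typed API.
    object SumLongAggregator extends Aggregator[Data, Long, Long] {
      def zero: Long = 0L
      def reduce(b: Long, d: Data): Long = b + d.l
      def merge(b1: Long, b2: Long): Long = b1 + b2
      def finish(b: Long): Long = b
      def bufferEncoder: Encoder[Long] = Encoders.scalaLong
      def outputEncoder: Encoder[Long] = Encoders.scalaLong
    }

    // Usage mirrors the benchmark case:
    //   ds.groupByKey(_.l % 10).agg(SumLongAggregator.toColumn)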
@@ -170,36 +176,39 @@ object DatasetBenchmark {
     val benchmark3 = aggregate(spark, numRows)
 
     /*
-    OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
-    Intel Xeon E3-12xx v2 (Ivy Bridge)
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+
     back-to-back map:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     ------------------------------------------------------------------------------------------------
-    RDD                                           3448 / 3646         29.0          34.5       1.0X
-    DataFrame                                     2647 / 3116         37.8          26.5       1.3X
-    Dataset                                       4781 / 5155         20.9          47.8       0.7X
+    RDD                                           3963 / 3976         25.2          39.6       1.0X
+    DataFrame                                      826 /  834        121.1           8.3       4.8X
+    Dataset                                       5178 / 5198         19.3          51.8       0.8X
     */
     benchmark.run()
 
     /*
-    OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
-    Intel Xeon E3-12xx v2 (Ivy Bridge)
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+
     back-to-back filter:                     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     ------------------------------------------------------------------------------------------------
-    RDD                                           1346 / 1618         74.3          13.5       1.0X
-    DataFrame                                       59 /   72       1695.4           0.6      22.8X
-    Dataset                                       2777 / 2805         36.0          27.8       0.5X
+    RDD                                            533 /  587        187.6           5.3       1.0X
+    DataFrame                                       79 /   91       1269.0           0.8       6.8X
+    Dataset                                        550 /  559        181.7           5.5       1.0X
     */
     benchmark2.run()
 
     /*
-    OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
-    Intel Xeon E3-12xx v2 (Ivy Bridge)
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+
     aggregate:                               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     ------------------------------------------------------------------------------------------------
-    RDD sum                                       1420 / 1523         70.4          14.2       1.0X
-    DataFrame sum                                   31 /   49       3214.3           0.3      45.6X
-    Dataset sum using Aggregator                  3216 / 3257         31.1          32.2       0.4X
-    Dataset complex Aggregator                    7948 / 8461         12.6          79.5       0.2X
+    RDD sum                                       1950 / 1995         51.3          19.5       1.0X
+    DataFrame sum                                  587 /  611        170.2           5.9       3.3X
+    Dataset sum using Aggregator                  3014 / 3222         33.2          30.1       0.6X
+    Dataset complex Aggregator                  32650 / 34505          3.1         326.5       0.1X
     */
     benchmark3.run()
   }
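Reading the tables: the first case is the 1.0X baseline, and each Relative figure works out to the baseline's best time divided by that case's best time (3963 / 826 is roughly 4.8 for DataFrame in the map table). A quick check in code:

    // Relative = baseline best time / this case's best time (first case = 1.0X).
    object RelativeColumnCheck extends App {
      val baselineMs = 3963.0                  // RDD, back-to-back map, best time
      println(f"${baselineMs / 826.0}%.1fX")   // DataFrame => 4.8X
      println(f"${baselineMs / 5178.0}%.1fX")  // Dataset   => 0.8X
    }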