From 252417fa21eb47781addfd614ff00dac793b52a9 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 2 Jun 2016 11:16:24 -0500
Subject: [PATCH] [SPARK-15322][SQL][FOLLOWUP] Use the new long accumulator
 for old int accumulators.

## What changes were proposed in this pull request?

This PR corrects the remaining cases that still use the old accumulator API, replacing `Accumulator[Int]` created via `sc.accumulator(0)` with the new `LongAccumulator` created via `sc.longAccumulator`.
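For reference, the migration pattern looks roughly like the minimal sketch below (the object name, app name, and `local` master are illustrative, not part of this patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object AccumulatorMigrationSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("accum"))

    // Old API (deprecated in 2.0): a generic Accumulator[Int]
    // backed by an implicit AccumulatorParam[Int].
    val oldAcc = sc.accumulator(0, "old-counter")
    sc.parallelize(1 to 10).foreach(x => oldAcc += x)

    // New API: a LongAccumulator registered with the SparkContext.
    val newAcc = sc.longAccumulator("new-counter")
    sc.parallelize(1 to 10).foreach(x => newAcc.add(x))

    // Both sum 1..10 on the driver once the jobs complete.
    assert(oldAcc.value == 55 && newAcc.value == 55)
    sc.stop()
  }
}
```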
Long = 55", output) } test("external vars") { diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 125686030c01f..48582c19163c9 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -150,13 +150,13 @@ class ReplSuite extends SparkFunSuite { test("simple foreach with accumulator") { val output = runInterpreter("local", """ - |val accum = sc.accumulator(0) - |sc.parallelize(1 to 10).foreach(x => accum += x) + |val accum = sc.longAccumulator + |sc.parallelize(1 to 10).foreach(x => accum.add(x)) |accum.value """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) - assertContains("res1: Int = 55", output) + assertContains("res1: Long = 55", output) } test("external vars") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 7ccc9de9db233..bd55e1a8751da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -21,7 +21,6 @@ import scala.collection.JavaConverters._ import org.apache.commons.lang.StringUtils -import org.apache.spark.Accumulator import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -36,7 +35,7 @@ import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{AccumulatorContext, ListAccumulator} +import org.apache.spark.util.{AccumulatorContext, ListAccumulator, LongAccumulator} private[sql] object InMemoryRelation { @@ -294,8 +293,8 @@ private[sql] case class InMemoryTableScanExec( sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean // Accumulators used for testing purposes - lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) - lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readPartitions = sparkContext.longAccumulator + lazy val readBatches = sparkContext.longAccumulator private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning @@ -339,7 +338,7 @@ private[sql] case class InMemoryTableScanExec( false } else { if (enableAccumulators) { - readBatches += 1 + readBatches.add(1) } true } @@ -361,7 +360,7 @@ private[sql] case class InMemoryTableScanExec( val columnarIterator = GenerateColumnAccessor.generate(columnTypes) columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) if (enableAccumulators && columnarIterator.hasNext) { - readPartitions += 1 + readPartitions.add(1) } columnarIterator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c77c889a1b7b8..f2c558ac2de7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import 
## How was this patch tested?

Existing tests cover this.

Author: hyukjinkwon

Closes #13434 from HyukjinKwon/accum.

---
 .../scala/org/apache/spark/DistributedSuite.scala       |  5 ++---
 .../test/scala/org/apache/spark/repl/ReplSuite.scala    |  6 +++---
 .../test/scala/org/apache/spark/repl/ReplSuite.scala    |  6 +++---
 .../execution/columnar/InMemoryTableScanExec.scala      | 11 +++++------
 .../apache/spark/sql/execution/debug/package.scala      |  7 ++++---
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala      |  6 +++---
 .../spark/sql/execution/ui/SQLListenerSuite.scala       |  4 ++--
 7 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 0be25e9f893b6..6e69fc4247079 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -92,8 +92,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
 
   test("accumulators") {
     sc = new SparkContext(clusterUrl, "test")
-    val accum = sc.accumulator(0)
-    sc.parallelize(1 to 10, 10).foreach(x => accum += x)
+    val accum = sc.longAccumulator
+    sc.parallelize(1 to 10, 10).foreach(x => accum.add(x))
     assert(accum.value === 55)
   }
 
@@ -109,7 +109,6 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
 
   test("repeatedly failing task") {
     sc = new SparkContext(clusterUrl, "test")
-    val accum = sc.accumulator(0)
     val thrown = intercept[SparkException] {
       // scalastyle:off println
       sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 547da8f713ac7..19f201f606dee 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -107,13 +107,13 @@ class ReplSuite extends SparkFunSuite {
   test("simple foreach with accumulator") {
     val output = runInterpreter("local",
       """
-        |val accum = sc.accumulator(0)
-        |sc.parallelize(1 to 10).foreach(x => accum += x)
+        |val accum = sc.longAccumulator
+        |sc.parallelize(1 to 10).foreach(x => accum.add(x))
         |accum.value
       """.stripMargin)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
-    assertContains("res1: Int = 55", output)
+    assertContains("res1: Long = 55", output)
   }
 
   test("external vars") {
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 125686030c01f..48582c19163c9 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -150,13 +150,13 @@ class ReplSuite extends SparkFunSuite {
   test("simple foreach with accumulator") {
     val output = runInterpreter("local",
       """
-        |val accum = sc.accumulator(0)
-        |sc.parallelize(1 to 10).foreach(x => accum += x)
+        |val accum = sc.longAccumulator
+        |sc.parallelize(1 to 10).foreach(x => accum.add(x))
         |accum.value
       """.stripMargin)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
-    assertContains("res1: Int = 55", output)
+    assertContains("res1: Long = 55", output)
   }
 
   test("external vars") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 7ccc9de9db233..bd55e1a8751da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -21,7 +21,6 @@ import scala.collection.JavaConverters._
 
 import org.apache.commons.lang.StringUtils
 
-import org.apache.spark.Accumulator
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -36,7 +35,7 @@ import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.UserDefinedType
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{AccumulatorContext, ListAccumulator}
+import org.apache.spark.util.{AccumulatorContext, ListAccumulator, LongAccumulator}
 
 
 private[sql] object InMemoryRelation {
@@ -294,8 +293,8 @@ private[sql] case class InMemoryTableScanExec(
     sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean
 
   // Accumulators used for testing purposes
-  lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
-  lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
+  lazy val readPartitions = sparkContext.longAccumulator
+  lazy val readBatches = sparkContext.longAccumulator
 
   private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
 
@@ -339,7 +338,7 @@ private[sql] case class InMemoryTableScanExec(
               false
             } else {
               if (enableAccumulators) {
-                readBatches += 1
+                readBatches.add(1)
               }
               true
             }
@@ -361,7 +360,7 @@ private[sql] case class InMemoryTableScanExec(
       val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
       columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray)
       if (enableAccumulators && columnarIterator.hasNext) {
-        readPartitions += 1
+        readPartitions.add(1)
       }
       columnarIterator
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index c77c889a1b7b8..f2c558ac2de7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.LongAccumulator
 
 /**
  * Contains methods for debugging query execution.
@@ -122,13 +123,13 @@ package object debug {
   /**
    * A collection of metrics for each column of output.
    *
-   * @param elementTypes the actual runtime types for the output.  Useful when there are bugs
+   * @param elementTypes the actual runtime types for the output. Useful when there are bugs
    *        causing the wrong data to be projected.
    */
   case class ColumnMetrics(
     elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
 
-  val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)
+  val tupleCount: LongAccumulator = sparkContext.longAccumulator
 
   val numColumns: Int = child.output.size
   val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics())
@@ -149,7 +150,7 @@ package object debug {
       def next(): InternalRow = {
         val currentRow = iter.next()
-        tupleCount += 1
+        tupleCount.add(1)
 
         var i = 0
         while (i < numColumns) {
           val value = currentRow.get(i, output(i).dataType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 91d93022df377..49a0ba1f1149b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2067,9 +2067,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
 
     // Identity udf that tracks the number of times it is called.
-    val countAcc = sparkContext.accumulator(0, "CallCount")
+    val countAcc = sparkContext.longAccumulator("CallCount")
     spark.udf.register("testUdf", (x: Int) => {
-      countAcc.++=(1)
+      countAcc.add(1)
       x
     })
 
@@ -2092,7 +2092,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
 
     val testUdf = functions.udf((x: Int) => {
-      countAcc.++=(1)
+      countAcc.add(1)
       x
     })
     verifyCallCount(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index cf7e976acc65f..6788c9d65f6ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -365,9 +365,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     // This task has both accumulators that are SQL metrics and accumulators that are not.
     // The listener should only track the ones that are actually SQL metrics.
     val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella")
-    val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
+    val nonSqlMetric = sparkContext.longAccumulator("baseball")
     val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.value), None)
-    val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)
+    val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.value), None)
     val taskInfo = createTaskInfo(0, 0)
     taskInfo.accumulables ++= Seq(sqlMetricInfo, nonSqlMetricInfo)
     val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null)