diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 0be25e9f893b6..6e69fc4247079 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -92,8 +92,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
 
   test("accumulators") {
     sc = new SparkContext(clusterUrl, "test")
-    val accum = sc.accumulator(0)
-    sc.parallelize(1 to 10, 10).foreach(x => accum += x)
+    val accum = sc.longAccumulator
+    sc.parallelize(1 to 10, 10).foreach(x => accum.add(x))
     assert(accum.value === 55)
   }
 
@@ -109,7 +109,6 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
 
   test("repeatedly failing task") {
     sc = new SparkContext(clusterUrl, "test")
-    val accum = sc.accumulator(0)
     val thrown = intercept[SparkException] {
       // scalastyle:off println
       sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 547da8f713ac7..19f201f606dee 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -107,13 +107,13 @@ class ReplSuite extends SparkFunSuite {
   test("simple foreach with accumulator") {
     val output = runInterpreter("local",
       """
-        |val accum = sc.accumulator(0)
-        |sc.parallelize(1 to 10).foreach(x => accum += x)
+        |val accum = sc.longAccumulator
+        |sc.parallelize(1 to 10).foreach(x => accum.add(x))
         |accum.value
       """.stripMargin)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
-    assertContains("res1: Int = 55", output)
+    assertContains("res1: Long = 55", output)
   }
 
   test("external vars") {
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 125686030c01f..48582c19163c9 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -150,13 +150,13 @@ class ReplSuite extends SparkFunSuite {
   test("simple foreach with accumulator") {
     val output = runInterpreter("local",
       """
-        |val accum = sc.accumulator(0)
-        |sc.parallelize(1 to 10).foreach(x => accum += x)
+        |val accum = sc.longAccumulator
+        |sc.parallelize(1 to 10).foreach(x => accum.add(x))
         |accum.value
       """.stripMargin)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
-    assertContains("res1: Int = 55", output)
+    assertContains("res1: Long = 55", output)
   }
 
   test("external vars") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 7ccc9de9db233..bd55e1a8751da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -21,7 +21,6 @@ import scala.collection.JavaConverters._
 
 import org.apache.commons.lang.StringUtils
 
-import org.apache.spark.Accumulator
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -36,7 +35,7 @@ import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.UserDefinedType
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{AccumulatorContext, ListAccumulator}
+import org.apache.spark.util.{AccumulatorContext, ListAccumulator, LongAccumulator}
 
 
 private[sql] object InMemoryRelation {
@@ -294,8 +293,8 @@ private[sql] case class InMemoryTableScanExec(
     sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean
 
   // Accumulators used for testing purposes
-  lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
-  lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
+  lazy val readPartitions = sparkContext.longAccumulator
+  lazy val readBatches = sparkContext.longAccumulator
 
   private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
 
@@ -339,7 +338,7 @@ private[sql] case class InMemoryTableScanExec(
               false
             } else {
               if (enableAccumulators) {
-                readBatches += 1
+                readBatches.add(1)
               }
               true
             }
@@ -361,7 +360,7 @@ private[sql] case class InMemoryTableScanExec(
       val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
       columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray)
       if (enableAccumulators && columnarIterator.hasNext) {
-        readPartitions += 1
+        readPartitions.add(1)
       }
       columnarIterator
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index c77c889a1b7b8..f2c558ac2de7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.LongAccumulator
 
 /**
  * Contains methods for debugging query execution.
@@ -122,13 +123,13 @@ package object debug {
     /**
      * A collection of metrics for each column of output.
      *
-     * @param elementTypes the actual runtime types for the output.  Useful when there are bugs
+     * @param elementTypes the actual runtime types for the output. Useful when there are bugs
      *                     causing the wrong data to be projected.
      */
     case class ColumnMetrics(
       elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
 
-    val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)
+    val tupleCount: LongAccumulator = sparkContext.longAccumulator
 
     val numColumns: Int = child.output.size
     val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics())
@@ -149,7 +150,7 @@ package object debug {
 
           def next(): InternalRow = {
             val currentRow = iter.next()
-            tupleCount += 1
+            tupleCount.add(1)
             var i = 0
             while (i < numColumns) {
               val value = currentRow.get(i, output(i).dataType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 91d93022df377..49a0ba1f1149b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2067,9 +2067,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
 
     // Identity udf that tracks the number of times it is called.
-    val countAcc = sparkContext.accumulator(0, "CallCount")
+    val countAcc = sparkContext.longAccumulator("CallCount")
     spark.udf.register("testUdf", (x: Int) => {
-      countAcc.++=(1)
+      countAcc.add(1)
       x
     })
 
@@ -2092,7 +2092,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"),
       Row(4, 2), 2)
     val testUdf = functions.udf((x: Int) => {
-      countAcc.++=(1)
+      countAcc.add(1)
       x
     })
     verifyCallCount(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index cf7e976acc65f..6788c9d65f6ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -365,9 +365,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     // This task has both accumulators that are SQL metrics and accumulators that are not.
     // The listener should only track the ones that are actually SQL metrics.
     val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella")
-    val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
+    val nonSqlMetric = sparkContext.longAccumulator("baseball")
     val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.value), None)
-    val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)
+    val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.value), None)
     val taskInfo = createTaskInfo(0, 0)
     taskInfo.accumulables ++= Seq(sqlMetricInfo, nonSqlMetricInfo)
     val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null)
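Note (not part of the patch): the hunks above replace the deprecated `sc.accumulator(0)` / `accum += x` API with `sc.longAccumulator` / `accum.add(x)`. Below is a minimal, self-contained sketch of that newer API, assuming Spark 2.x; the object name, app name, and master URL are illustrative only, while the calls (`longAccumulator`, `add`, `value`, `count`, `avg`) mirror the usage introduced in the diff.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Illustrative driver program showing the accumulator API that the patch migrates to.
object LongAccumulatorExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("long-accumulator-example").setMaster("local[2]"))

    // Created on the driver; passing a name makes the accumulator visible in the web UI.
    val accum = sc.longAccumulator("sum")

    // Tasks update it with add(); the old `accum += x` operator no longer exists.
    sc.parallelize(1 to 10, 10).foreach(x => accum.add(x))

    // Only the driver reads a reliable result: 1 + 2 + ... + 10 = 55.
    println(s"value = ${accum.value}, count = ${accum.count}, avg = ${accum.avg}")

    sc.stop()
  }
}
```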
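Also not part of the patch: the `debug` package's `ColumnMetrics` still defaults to the legacy `Accumulator[HashSet[String]]` in these hunks. For context, here is a hedged sketch of how a comparable set-collecting accumulator could be expressed with `AccumulatorV2` (Spark 2.x). `StringSetAccumulator` is a hypothetical name introduced for illustration; Spark's own `ListAccumulator`/`CollectionAccumulator` covers a similar need.

```scala
import scala.collection.mutable

import org.apache.spark.util.AccumulatorV2

// Hypothetical accumulator collecting distinct strings (e.g. runtime type names),
// analogous to what ColumnMetrics.elementTypes tracks. Simplified sketch: no extra
// synchronization beyond single-threaded per-task use.
class StringSetAccumulator extends AccumulatorV2[String, Set[String]] {
  private val set = mutable.HashSet.empty[String]

  override def isZero: Boolean = set.isEmpty
  override def copy(): StringSetAccumulator = {
    val acc = new StringSetAccumulator
    acc.set ++= set
    acc
  }
  override def reset(): Unit = set.clear()
  override def add(v: String): Unit = set += v
  override def merge(other: AccumulatorV2[String, Set[String]]): Unit = set ++= other.value
  override def value: Set[String] = set.toSet
}

// Usage sketch (sc is an existing SparkContext, rdd an existing RDD):
//   val types = new StringSetAccumulator
//   sc.register(types, "elementTypes")
//   rdd.foreach(row => types.add(row.getClass.getName))
//   println(types.value)
```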