[SPARK-35215][SQL] Update custom metric per certain rows and at the end of the task

### What changes were proposed in this pull request?

This patch changes custom metric updating so that metrics are updated every certain number of rows (currently 100) instead of on every row, plus one final update at the end of the task.

### Why are the changes needed?

Based on the previous discussion in #31451 (comment), we should only update custom metrics every certain number of rows (e.g. 100) and once more at the end of the task. Updating on every row brings little benefit.
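
To make this concrete, here is a minimal sketch of the throttled-update pattern, assuming a hypothetical `reportMetrics` callback in place of the real metric plumbing (`ThrottledIterator` and `rowsPerUpdate` are illustrative names, not Spark code):

```scala
// Minimal sketch of the throttled-update pattern; ThrottledIterator, reportMetrics
// and rowsPerUpdate are hypothetical names, not part of Spark.
class ThrottledIterator[T](underlying: Iterator[T], reportMetrics: () => Unit)
  extends Iterator[T] {

  private val rowsPerUpdate = 100
  private var numRow = 0L

  override def hasNext: Boolean = underlying.hasNext

  override def next(): T = {
    // Report on rows 0, 100, 200, ... rather than on every row.
    if (numRow % rowsPerUpdate == 0) {
      reportMetrics()
    }
    numRow += 1
    underlying.next()
  }
}
```

A task-completion listener still publishes the metric values once more at the end of the task, so nothing is lost when the iterator is not fully consumed.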

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing unit test.

Closes #32330 from viirya/metric-update.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
viirya authored and cloud-fan committed May 6, 2021
1 parent c6d3f37 commit 6cd5cf5
Showing 3 changed files with 32 additions and 11 deletions.

DataSourceRDD.scala
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.vectorized.ColumnarBatch

 class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition)
@@ -66,7 +66,12 @@ class DataSourceRDD(
         new PartitionIterator[InternalRow](rowReader, customMetrics))
       (iter, rowReader)
     }
-    context.addTaskCompletionListener[Unit](_ => reader.close())
+    context.addTaskCompletionListener[Unit] { _ =>
+      // In case of early stopping before consuming the entire iterator,
+      // we need to do one more metric update at the end of the task.
+      CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
+      reader.close()
+    }
     // TODO: SPARK-25083 remove the type erasure hack in data source scan
     new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]])
   }
@@ -81,6 +86,8 @@ private class PartitionIterator[T](
     customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
   private[this] var valuePrepared = false

+  private var numRow = 0L
+
   override def hasNext: Boolean = {
     if (!valuePrepared) {
       valuePrepared = reader.next()
@@ -92,12 +99,10 @@
     if (!hasNext) {
       throw QueryExecutionErrors.endOfStreamError()
     }
-    reader.currentMetricsValues.foreach { metric =>
-      assert(customMetrics.contains(metric.name()),
-        s"Custom metrics ${customMetrics.keys.mkString(", ")} do not contain the metric " +
-          s"${metric.name()}")
-      customMetrics(metric.name()).set(metric.value())
+    if (numRow % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
+      CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
     }
+    numRow += 1
     valuePrepared = false
     reader.get()
   }
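
For context on where `reader.currentMetricsValues` comes from, here is a hedged sketch of a data source `PartitionReader` that exposes one custom task metric. The `CountingReader` class and the `rowsRead` metric name are made up, and it assumes the connector API in which `currentMetricsValues()` returns an array of `CustomTaskMetric`:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read.PartitionReader

// Hypothetical wrapper: counts rows read and reports the count as a custom task metric.
class CountingReader(underlying: PartitionReader[InternalRow])
  extends PartitionReader[InternalRow] {

  private var rowsRead = 0L

  override def next(): Boolean = {
    val hasMore = underlying.next()
    if (hasMore) rowsRead += 1
    hasMore
  }

  override def get(): InternalRow = underlying.get()

  override def close(): Unit = underlying.close()

  // These are the values DataSourceRDD now forwards to the SQL metrics every
  // NUM_ROWS_PER_UPDATE rows and once more from the task-completion listener.
  override def currentMetricsValues(): Array[CustomTaskMetric] = Array(
    new CustomTaskMetric {
      override def name(): String = "rowsRead"
      override def value(): Long = rowsRead
    }
  )
}
```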

CustomMetrics.scala
@@ -17,11 +17,13 @@

 package org.apache.spark.sql.execution.metric

-import org.apache.spark.sql.connector.metric.CustomMetric
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}

 object CustomMetrics {
   private[spark] val V2_CUSTOM = "v2Custom"

+  private[spark] val NUM_ROWS_PER_UPDATE = 100
+
   /**
    * Given a class name, builds and returns a metric type for a V2 custom metric class
    * `CustomMetric`.
@@ -41,4 +43,15 @@ object CustomMetrics {
       None
     }
   }
+
+  /**
+   * Updates given custom metrics.
+   */
+  def updateMetrics(
+      currentMetricsValues: Seq[CustomTaskMetric],
+      customMetrics: Map[String, SQLMetric]): Unit = {
+    currentMetricsValues.foreach { metric =>
+      customMetrics(metric.name()).set(metric.value())
+    }
+  }
 }
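
The new `updateMetrics` helper pairs metrics purely by name, so the executor-side `CustomTaskMetric` and the driver-side `CustomMetric` declared by a source must share a name; it also drops the old per-row assert, so a name missing from `customMetrics` now surfaces as a map lookup failure instead. Below is a hedged sketch of that pairing, with made-up class and metric names, assuming the connector metric interfaces (`name`/`description`/`aggregateTaskMetrics` on `CustomMetric`, `name`/`value` on `CustomTaskMetric`):

```scala
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}

// Driver-side definition: decides how per-task values are aggregated for the SQL UI.
class BytesReadMetric extends CustomMetric {
  override def name(): String = "bytesRead"
  override def description(): String = "number of bytes read by this scan"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString
}

// Executor-side value: reported through currentMetricsValues and written by
// updateMetrics into the SQLMetric registered under the same name.
class BytesReadTaskMetric(current: Long) extends CustomTaskMetric {
  override def name(): String = "bytesRead" // must match BytesReadMetric.name()
  override def value(): Long = current
}
```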

ContinuousDataSourceRDD.scala
@@ -22,7 +22,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.connector.read.streaming.ContinuousPartitionReaderFactory
-import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.NextIterator

@@ -92,10 +92,13 @@ class ContinuousDataSourceRDD(

     val partitionReader = readerForPartition.getPartitionReader()
     new NextIterator[InternalRow] {
+      private var numRow = 0L
+
       override def getNext(): InternalRow = {
-        partitionReader.currentMetricsValues.foreach { metric =>
-          customMetrics(metric.name()).set(metric.value())
+        if (numRow % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
+          CustomMetrics.updateMetrics(partitionReader.currentMetricsValues, customMetrics)
         }
+        numRow += 1
         readerForPartition.next() match {
           case null =>
             finished = true
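
The continuous-streaming iterator uses the same update cadence as the batch `PartitionIterator` above. A tiny self-contained check of that cadence (plain Scala, not Spark code; the object name is made up):

```scala
// Counts how many metric updates the modulo check triggers while 250 rows are produced.
object UpdateCadenceDemo extends App {
  val NUM_ROWS_PER_UPDATE = 100
  var numRow = 0L
  var updates = 0

  (1 to 250).foreach { _ =>
    if (numRow % NUM_ROWS_PER_UPDATE == 0) updates += 1
    numRow += 1
  }

  // Rows 0, 100 and 200 trigger an update: 3 updates for 250 rows. In the batch
  // path, DataSourceRDD additionally publishes final values from its
  // task-completion listener when the task ends.
  assert(updates == 3)
  println(s"updates while iterating = $updates")
}
```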