[SPARK-34338][SQL] Report metrics from Datasource v2 scan
### What changes were proposed in this pull request?

This patch proposes to leverage the `CustomMetric` and `CustomTaskMetric` APIs to report custom metrics from a DS v2 scan to Spark.
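
As a rough sketch of the connector-facing side of this API (class and metric names below are made up for illustration, and the import package follows the `org.apache.spark.sql.connector` location used in this change, which may differ in later versions): a source declares the metrics it supports via `CustomMetric`, and its `PartitionReader` reports per-task values via `CustomTaskMetric`.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.connector.read.PartitionReader

// Driver-side metric definition: Spark aggregates the collected per-task values
// by calling aggregateTaskMetrics and shows the result in the SQL UI.
class BytesReadMetric extends CustomMetric {
  override def name(): String = "bytesRead"
  override def description(): String = "total bytes read by this scan"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString
}

// Executor-side value; name() must match one of the metrics the scan supports.
case class BytesReadTaskMetric(bytes: Long) extends CustomTaskMetric {
  override def name(): String = "bytesRead"
  override def value(): Long = bytes
}

// A reader only needs to override currentMetricsValues(); Spark polls it while
// iterating over the partition, so it should stay cheap.
class IllustrativePartitionReader extends PartitionReader[InternalRow] {
  private var bytesRead = 0L

  override def next(): Boolean = false        // no real data in this sketch
  override def get(): InternalRow = throw new NoSuchElementException("empty reader")
  override def close(): Unit = {}

  override def currentMetricsValues(): Array[CustomTaskMetric] =
    Array(BytesReadTaskMetric(bytesRead))
}

// On the Scan side, the metric is declared once so Spark can register an SQLMetric for it:
//   override def supportedCustomMetrics(): Array[CustomMetric] = Array(new BytesReadMetric)
```

Spark then matches each reported `CustomTaskMetric` by name against the `SQLMetric`s registered for the scan's supported custom metrics, which is what the `DataSourceRDD` and `DataSourceV2ScanExecBase` changes below wire up.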

### Why are the changes needed?

This is related to #31398. In SPARK-34297, we want to add a couple of metrics when reading from Kafka in Structured Streaming, which requires a public API change in DS v2. This PR extracts only the DS v2 change and makes it general for DS v2 instead of specific to the micro-batch DS v2 API.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test.

Also implemented a simple test DS v2 class locally and ran it:

```scala
scala> import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.execution.datasources.v2._

scala> classOf[CustomMetricDataSourceV2].getName
res0: String = org.apache.spark.sql.execution.datasources.v2.CustomMetricDataSourceV2

scala> val df = spark.read.format(res0).load()
df: org.apache.spark.sql.DataFrame = [i: int, j: int]

scala> df.collect
```

<img width="703" alt="Screen Shot 2021-03-30 at 11 07 13 PM" src="https://user-images.githubusercontent.com/68855/113098080-d8a49800-91ac-11eb-8681-be408a0f2e69.png">

Closes #31451 from viirya/dsv2-metrics.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
viirya authored and cloud-fan committed Apr 20, 2021
1 parent 3614448 commit eb9a439
Showing 13 changed files with 308 additions and 16 deletions.
@@ -51,7 +51,8 @@ public interface PartitionReader<T> extends Closeable {
T get();

/**
* Returns an array of custom task metrics. By default it returns empty array.
* Returns an array of custom task metrics. By default it returns an empty array. Note that it is
* not recommended to put heavy logic in this method as it may affect reading performance.
*/
default CustomTaskMetric[] currentMetricsValues() {
CustomTaskMetric[] NO_METRICS = {};
@@ -45,7 +45,8 @@ case class BatchScanExec(
override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()

override lazy val inputRDD: RDD[InternalRow] = {
new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar)
new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics)
}

override def doCanonicalize(): BatchScanExec = {
@@ -58,6 +58,7 @@ case class ContinuousScanExec(
sqlContext.conf.continuousStreamingExecutorPollIntervalMs,
partitions,
schema,
readerFactory.asInstanceOf[ContinuousPartitionReaderFactory])
readerFactory.asInstanceOf[ContinuousPartitionReaderFactory],
customMetrics)
}
}
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.vectorized.ColumnarBatch

class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition)
@@ -37,7 +38,8 @@ class DataSourceRDD(
sc: SparkContext,
@transient private val inputPartitions: Seq[InputPartition],
partitionReaderFactory: PartitionReaderFactory,
columnarReads: Boolean)
columnarReads: Boolean,
customMetrics: Map[String, SQLMetric])
extends RDD[InternalRow](sc, Nil) {

override protected def getPartitions: Array[Partition] = {
@@ -55,11 +57,13 @@ class DataSourceRDD(
val inputPartition = castPartition(split).inputPartition
val (iter, reader) = if (columnarReads) {
val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
val iter = new MetricsBatchIterator(new PartitionIterator[ColumnarBatch](batchReader))
val iter = new MetricsBatchIterator(
new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
(iter, batchReader)
} else {
val rowReader = partitionReaderFactory.createReader(inputPartition)
val iter = new MetricsRowIterator(new PartitionIterator[InternalRow](rowReader))
val iter = new MetricsRowIterator(
new PartitionIterator[InternalRow](rowReader, customMetrics))
(iter, rowReader)
}
context.addTaskCompletionListener[Unit](_ => reader.close())
@@ -72,7 +76,9 @@ class DataSourceRDD(
}
}

private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] {
private class PartitionIterator[T](
reader: PartitionReader[T],
customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
private[this] var valuePrepared = false

override def hasNext: Boolean = {
@@ -86,6 +92,12 @@ private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[
if (!hasNext) {
throw QueryExecutionErrors.endOfStreamError()
}
reader.currentMetricsValues.foreach { metric =>
assert(customMetrics.contains(metric.name()),
s"Custom metrics ${customMetrics.keys.mkString(", ")} do not contain the metric " +
s"${metric.name()}")
customMetrics(metric.name()).set(metric.value())
}
valuePrepared = false
reader.get()
}
@@ -32,8 +32,14 @@ import org.apache.spark.util.Utils

trait DataSourceV2ScanExecBase extends LeafExecNode {

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
lazy val customMetrics = scan.supportedCustomMetrics().map { customMetric =>
customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
}.toMap

override lazy val metrics = {
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) ++
customMetrics
}

def scan: Scan

@@ -46,6 +46,6 @@ case class MicroBatchScanExec(
override lazy val readerFactory: PartitionReaderFactory = stream.createReaderFactory()

override lazy val inputRDD: RDD[InternalRow] = {
new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar)
new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics)
}
}
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.metric

import java.text.NumberFormat
import java.util.Locale

import org.apache.spark.sql.connector.CustomMetric

object CustomMetrics {
private[spark] val V2_CUSTOM = "v2Custom"

/**
* Given a V2 custom metric (`CustomMetric`), builds and returns its metric type name.
*/
def buildV2CustomMetricTypeName(customMetric: CustomMetric): String = {
s"${V2_CUSTOM}_${customMetric.getClass.getCanonicalName}"
}

/**
* Given a V2 custom metric type name, this method parses it and returns the corresponding
* `CustomMetric` class name.
*/
def parseV2CustomMetricType(metricType: String): Option[String] = {
if (metricType.startsWith(s"${V2_CUSTOM}_")) {
Some(metricType.drop(V2_CUSTOM.length + 1))
} else {
None
}
}
}

/**
* Built-in `CustomMetric` that sums up metric values.
*/
class CustomSumMetric extends CustomMetric {
override def name(): String = "CustomSumMetric"

override def description(): String = "Sum up CustomMetric"

override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
taskMetrics.sum.toString
}
}

/**
* Built-in `CustomMetric` that computes average of metric values.
*/
class CustomAvgMetric extends CustomMetric {
override def name(): String = "CustomAvgMetric"

override def description(): String = "Average CustomMetric"

override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
val average = if (taskMetrics.isEmpty) {
0.0
} else {
taskMetrics.sum.toDouble / taskMetrics.length
}
val numberFormat = NumberFormat.getNumberInstance(Locale.US)
numberFormat.format(average)
}
}
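
For reference, a small hypothetical snippet (not part of the diff) showing how the helpers above behave: the metric type string encodes the `CustomMetric` class name, `parseV2CustomMetricType` recovers it, and `CustomAvgMetric` formats the mean using `NumberFormat`'s default of three fraction digits.

```scala
import org.apache.spark.sql.execution.metric.{CustomAvgMetric, CustomMetrics, CustomSumMetric}

val sum = new CustomSumMetric
// The type name is "v2Custom_" + the metric's canonical class name ...
val metricType = CustomMetrics.buildV2CustomMetricTypeName(sum)
// ... which parseV2CustomMetricType turns back into the class name; built-in metric
// types such as "sum" yield None and fall back to SQLMetrics.stringValue in the listener.
assert(CustomMetrics.parseV2CustomMetricType(metricType)
  .contains(sum.getClass.getCanonicalName))
assert(CustomMetrics.parseV2CustomMetricType("sum").isEmpty)

val avg = new CustomAvgMetric
println(avg.aggregateTaskMetrics(Array(1L, 2L, 4L)))  // expected: 2.333
```
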
@@ -24,6 +24,7 @@ import scala.concurrent.duration._

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.sql.connector.CustomMetric
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}

@@ -107,6 +108,15 @@ object SQLMetrics {
acc
}

/**
* Create a metric to report data source v2 custom metric.
*/
def createV2CustomMetric(sc: SparkContext, customMetric: CustomMetric): SQLMetric = {
val acc = new SQLMetric(CustomMetrics.buildV2CustomMetricTypeName(customMetric))
acc.register(sc, name = Some(customMetric.name()), countFailedValues = false)
acc
}

/**
* Create a metric to report the size information (including total, min, med, max) like data size,
* spill size, etc.
@@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.connector.read.streaming.ContinuousPartitionReaderFactory
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.NextIterator

@@ -52,7 +53,8 @@ class ContinuousDataSourceRDD(
epochPollIntervalMs: Long,
private val inputPartitions: Seq[InputPartition],
schema: StructType,
partitionReaderFactory: ContinuousPartitionReaderFactory)
partitionReaderFactory: ContinuousPartitionReaderFactory,
customMetrics: Map[String, SQLMetric])
extends RDD[InternalRow](sc, Nil) {

override protected def getPartitions: Array[Partition] = {
@@ -88,8 +90,12 @@ class ContinuousDataSourceRDD(
partition.queueReader
}

val partitionReader = readerForPartition.getPartitionReader()
new NextIterator[InternalRow] {
override def getNext(): InternalRow = {
partitionReader.currentMetricsValues.foreach { metric =>
customMetrics(metric.name()).set(metric.value())
}
readerForPartition.next() match {
case null =>
finished = true
@@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, PartitionOffset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils
@@ -47,6 +48,8 @@ class ContinuousQueuedDataReader(
// Important sequencing - we must get our starting point before the provider threads start running
private var currentOffset: PartitionOffset = reader.getOffset

def getPartitionReader(): PartitionReader[InternalRow] = reader

/**
* The record types in the read buffer.
*/
@@ -22,16 +22,19 @@ import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.spark.{JobExecutionStatus, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Status._
import org.apache.spark.scheduler._
import org.apache.spark.sql.connector.CustomMetric
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashMap

class SQLAppStatusListener(
@@ -199,7 +202,37 @@ class SQLAppStatusListener(
}

private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
val accumIds = exec.metrics.map(_.accumulatorId).toSet

val metricAggregationMap = new mutable.HashMap[String, (Array[Long], Array[Long]) => String]()
val metricAggregationMethods = exec.metrics.map { m =>
val optClassName = CustomMetrics.parseV2CustomMetricType(m.metricType)
val metricAggMethod = optClassName.map { className =>
if (metricAggregationMap.contains(className)) {
metricAggregationMap(className)
} else {
// Try to instantiate the custom metric object
try {
val metric = Utils.loadExtensions(classOf[CustomMetric], Seq(className), conf).head
val method =
(metrics: Array[Long], _: Array[Long]) => metric.aggregateTaskMetrics(metrics)
metricAggregationMap.put(className, method)
method
} catch {
case NonFatal(_) =>
// Cannot instantiate the custom metric object; we might be in the history server, which
// does not have the custom metric class.
val defaultMethod = (_: Array[Long], _: Array[Long]) => "N/A"
metricAggregationMap.put(className, defaultMethod)
defaultMethod
}
}
}.getOrElse(
// Built-in SQLMetric
SQLMetrics.stringValue(m.metricType, _, _)
)
(m.accumulatorId, metricAggMethod)
}.toMap

val liveStageMetrics = exec.stages.toSeq
.flatMap { stageId => Option(stageMetrics.get(stageId)) }
@@ -212,7 +245,7 @@ class SQLAppStatusListener(

val maxMetricsFromAllStages = new mutable.HashMap[Long, Array[Long]]()

taskMetrics.filter(m => metricTypes.contains(m._1)).foreach { case (id, values) =>
taskMetrics.filter(m => accumIds.contains(m._1)).foreach { case (id, values) =>
val prev = allMetrics.getOrElse(id, null)
val updated = if (prev != null) {
prev ++ values
@@ -223,7 +256,7 @@ class SQLAppStatusListener(
}

// Find the max for each metric id between all stages.
val validMaxMetrics = maxMetrics.filter(m => metricTypes.contains(m._1))
val validMaxMetrics = maxMetrics.filter(m => accumIds.contains(m._1))
validMaxMetrics.foreach { case (id, value, taskId, stageId, attemptId) =>
val updated = maxMetricsFromAllStages.getOrElse(id, Array(value, stageId, attemptId, taskId))
if (value > updated(0)) {
@@ -236,7 +269,7 @@ class SQLAppStatusListener(
}

exec.driverAccumUpdates.foreach { case (id, value) =>
if (metricTypes.contains(id)) {
if (accumIds.contains(id)) {
val prev = allMetrics.getOrElse(id, null)
val updated = if (prev != null) {
// If the driver updates same metrics as tasks and has higher value then remove
@@ -256,7 +289,7 @@ class SQLAppStatusListener(
}

val aggregatedMetrics = allMetrics.map { case (id, values) =>
id -> SQLMetrics.stringValue(metricTypes(id), values, maxMetricsFromAllStages.getOrElse(id,
id -> metricAggregationMethods(id)(values, maxMetricsFromAllStages.getOrElse(id,
Array.empty[Long]))
}.toMap
