[SPARK-34338][SQL] Report metrics from Datasource v2 scan #31451

Closed · wants to merge 16 commits
@@ -51,7 +51,8 @@ public interface PartitionReader<T> extends Closeable {
T get();

/**
* Returns an array of custom task metrics. By default it returns empty array.
* Returns an array of custom task metrics. By default it returns empty array. Note that it is
* not recommended to put heavy logic in this method as it may affect reading performance.
*/
default CustomTaskMetric[] currentMetricsValues() {
CustomTaskMetric[] NO_METRICS = {};
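For illustration, a minimal Scala sketch of a connector-side reader implementing this hook. The class and metric names are hypothetical, and the CustomTaskMetric import path is assumed to match the CustomMetric one used elsewhere in this diff:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.CustomTaskMetric
import org.apache.spark.sql.connector.read.PartitionReader

// Hypothetical task-side metric: a snapshot of how many rows this reader has produced so far.
class RowsScannedTaskMetric(snapshot: Long) extends CustomTaskMetric {
  override def name(): String = "rowsScanned"
  override def value(): Long = snapshot
}

// Hypothetical reader: the counter is bumped in next(), so currentMetricsValues() stays cheap,
// per the doc note above about avoiding heavy logic on the read path.
class ExampleRowReader(rows: Iterator[InternalRow]) extends PartitionReader[InternalRow] {
  private var current: InternalRow = _
  private var rowsScanned = 0L

  override def next(): Boolean = {
    if (rows.hasNext) { current = rows.next(); rowsScanned += 1; true } else false
  }

  override def get(): InternalRow = current

  override def close(): Unit = {}

  override def currentMetricsValues(): Array[CustomTaskMetric] = {
    Array(new RowsScannedTaskMetric(rowsScanned))
  }
}
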
@@ -45,7 +45,8 @@ case class BatchScanExec(
override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()

override lazy val inputRDD: RDD[InternalRow] = {
new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar)
new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics)
}

override def doCanonicalize(): BatchScanExec = {
@@ -58,6 +58,7 @@ case class ContinuousScanExec(
sqlContext.conf.continuousStreamingExecutorPollIntervalMs,
partitions,
schema,
readerFactory.asInstanceOf[ContinuousPartitionReaderFactory])
readerFactory.asInstanceOf[ContinuousPartitionReaderFactory],
customMetrics)
}
}
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.vectorized.ColumnarBatch

class DataSourceRDDPartition(val index: Int, val inputPartition: InputPartition)
@@ -37,7 +38,8 @@ class DataSourceRDD(
sc: SparkContext,
@transient private val inputPartitions: Seq[InputPartition],
partitionReaderFactory: PartitionReaderFactory,
columnarReads: Boolean)
columnarReads: Boolean,
customMetrics: Map[String, SQLMetric])
extends RDD[InternalRow](sc, Nil) {

override protected def getPartitions: Array[Partition] = {
@@ -55,11 +57,13 @@
val inputPartition = castPartition(split).inputPartition
val (iter, reader) = if (columnarReads) {
val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
val iter = new MetricsBatchIterator(new PartitionIterator[ColumnarBatch](batchReader))
val iter = new MetricsBatchIterator(
new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
(iter, batchReader)
} else {
val rowReader = partitionReaderFactory.createReader(inputPartition)
val iter = new MetricsRowIterator(new PartitionIterator[InternalRow](rowReader))
val iter = new MetricsRowIterator(
new PartitionIterator[InternalRow](rowReader, customMetrics))
(iter, rowReader)
}
context.addTaskCompletionListener[Unit](_ => reader.close())
@@ -72,7 +76,9 @@ class DataSourceRDD(
}
}

private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[T] {
private class PartitionIterator[T](
reader: PartitionReader[T],
customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
private[this] var valuePrepared = false

override def hasNext: Boolean = {
@@ -86,6 +92,9 @@ private class PartitionIterator[T](reader: PartitionReader[T]) extends Iterator[
if (!hasNext) {
throw QueryExecutionErrors.endOfStreamError()
}
reader.currentMetricsValues.foreach { metric =>
Contributor: it's still per-row update?

Member Author: yea, will do "update metrics per some rows" in a followup.

customMetrics(metric.name()).set(metric.value())
Member: Do we need to check whether customMetrics contains metric.name()?

Member Author: We can throw a user-friendly message if it doesn't.

}
valuePrepared = false
reader.get()
}
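As a hedged sketch of the validation discussed in the thread above (this diff indexes the map directly), the per-row update could fail with a user-friendly message when a reader reports a metric the scan never declared; the error wording below is hypothetical:

reader.currentMetricsValues.foreach { metric =>
  customMetrics.get(metric.name()) match {
    case Some(sqlMetric) => sqlMetric.set(metric.value())
    case None =>
      // Hypothetical message; the real wording is not part of this diff.
      throw new IllegalStateException(
        s"Custom metric '${metric.name()}' is not declared by Scan.supportedCustomMetrics()")
  }
}
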
@@ -32,8 +32,14 @@ import org.apache.spark.util.Utils

trait DataSourceV2ScanExecBase extends LeafExecNode {

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
lazy val customMetrics = scan.supportedCustomMetrics().map { customMetric =>
customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
}.toMap

override lazy val metrics = {
Contributor: Any plans to similarly support metrics in writes?

Member Author (@viirya, Feb 4, 2021): I looked at a few V2 write nodes, but it seems we don't have any SQL metrics there (not even number of output rows). I guess we generally don't provide metrics for writes now? If there is interest in metrics for writes, I think it is okay to work on that later.

Contributor: Looks like updating context.taskMetrics().outputMetrics is just in our branch. That just uses the Hadoop FS metrics collection that we use elsewhere, so it isn't metrics from the source as we want to support in this PR. I think it would be good to follow up and support metrics on the output side. It doesn't need to be done here, but metrics are really useful.

Member Author: Sounds good to me. I can work on it in follow-up PRs.

Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) ++
customMetrics
}

def scan: Scan

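For context, a hedged sketch of the connector side that feeds customMetrics: a Scan declares its metrics via supportedCustomMetrics(), and the names must match what its readers report per task. The scan and metric classes below are hypothetical:

import org.apache.spark.sql.connector.CustomMetric
import org.apache.spark.sql.connector.read.{Batch, Scan}
import org.apache.spark.sql.execution.metric.CustomSumMetric
import org.apache.spark.sql.types.StructType

// Hypothetical driver-side metric: a concrete top-level class with a no-arg constructor,
// so the listener can re-instantiate it by class name (see the loadExtensions thread below).
class RowsScannedMetric extends CustomSumMetric {
  override def name(): String = "rowsScanned"
  override def description(): String = "number of rows scanned"
}

// Hypothetical scan wiring the metric in; toBatch() would return the real Batch implementation.
class ExampleScan(batch: Batch, schema: StructType) extends Scan {
  override def readSchema(): StructType = schema
  override def toBatch(): Batch = batch
  override def supportedCustomMetrics(): Array[CustomMetric] = Array(new RowsScannedMetric)
}
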
@@ -46,6 +46,6 @@ case class MicroBatchScanExec(
override lazy val readerFactory: PartitionReaderFactory = stream.createReaderFactory()

override lazy val inputRDD: RDD[InternalRow] = {
new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar)
new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar, customMetrics)
}
}
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.metric

import org.apache.spark.sql.connector.CustomMetric

object CustomMetrics {
private[spark] val V2_CUSTOM = "v2Custom"

/**
* Given a class name, builds and returns a metric type for a V2 custom metric class
* `CustomMetric`.
*/
def buildV2CustomMetricTypeName(customMetric: CustomMetric): String = {
s"${V2_CUSTOM}_${customMetric.getClass.getCanonicalName}"
}

/**
* Given a V2 custom metric type name, this method parses it and returns the corresponding
* `CustomMetric` class name.
*/
def parseV2CustomMetricType(metricType: String): Option[String] = {
val className = metricType.stripPrefix(s"${V2_CUSTOM}_")
Contributor: nit: to make the code more readable

if (metricType.startsWith(s"${V2_CUSTOM}_")) {
  Some(metricType.drop(V2_CUSTOM.length + 1))
} else {
  None
}

Member: +1

if (className == metricType) {
None
} else {
Some(className)
}
}
}

/**
* Built-in `CustomMetric` that sums up metric values.
*/
class CustomSumMetric extends CustomMetric {
override def name(): String = "CustomSumMetric"

override def description(): String = "Sum up CustomMetric"

override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
taskMetrics.sum.toString
}
}

/**
* Built-in `CustomMetric` that computes average of metric values.
*/
class CustomAvgMetric extends CustomMetric {
override def name(): String = "CustomAvgMetric"

override def description(): String = "Average CustomMetric"

override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
val average = if (taskMetrics.isEmpty) {
0L
} else {
taskMetrics.sum / taskMetrics.length
Member: Why choose the Long type instead of the Double type here?

Member Author: No special reason. We can use Double to show digits after the decimal point.

}
average.toString
}
}
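A short usage sketch of the two helpers above; the expected strings follow directly from the code, but treat them as illustrative:

val metric = new CustomSumMetric
val typeName = CustomMetrics.buildV2CustomMetricTypeName(metric)
// typeName == "v2Custom_org.apache.spark.sql.execution.metric.CustomSumMetric"

// Round trip back to the class name, used later by SQLAppStatusListener.
assert(CustomMetrics.parseV2CustomMetricType(typeName)
  .contains("org.apache.spark.sql.execution.metric.CustomSumMetric"))

// A built-in metric type such as "sum" has no v2Custom_ prefix, so parsing yields None.
assert(CustomMetrics.parseV2CustomMetricType("sum").isEmpty)
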
@@ -24,6 +24,7 @@ import scala.concurrent.duration._

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.sql.connector.CustomMetric
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}

@@ -107,6 +108,15 @@ object SQLMetrics {
acc
}

/**
* Create a metric to report data source v2 custom metric.
*/
def createV2CustomMetric(sc: SparkContext, customMetric: CustomMetric): SQLMetric = {
val acc = new SQLMetric(CustomMetrics.buildV2CustomMetricTypeName(customMetric))
acc.register(sc, name = Some(customMetric.name()), countFailedValues = false)
acc
}

/**
* Create a metric to report the size information (including total, min, med, max) like data size,
* spill size, etc.
@@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.connector.read.streaming.ContinuousPartitionReaderFactory
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.NextIterator

@@ -52,7 +53,8 @@ class ContinuousDataSourceRDD(
epochPollIntervalMs: Long,
private val inputPartitions: Seq[InputPartition],
schema: StructType,
partitionReaderFactory: ContinuousPartitionReaderFactory)
partitionReaderFactory: ContinuousPartitionReaderFactory,
customMetrics: Map[String, SQLMetric])
extends RDD[InternalRow](sc, Nil) {

override protected def getPartitions: Array[Partition] = {
@@ -88,8 +90,12 @@ class ContinuousDataSourceRDD(
partition.queueReader
}

val partitionReader = readerForPartition.getPartitionReader()
new NextIterator[InternalRow] {
override def getNext(): InternalRow = {
partitionReader.currentMetricsValues.foreach { metric =>
customMetrics(metric.name()).set(metric.value())
}
readerForPartition.next() match {
case null =>
finished = true
@@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, PartitionOffset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils
@@ -47,6 +48,8 @@ class ContinuousQueuedDataReader(
// Important sequencing - we must get our starting point before the provider threads start running
private var currentOffset: PartitionOffset = reader.getOffset

def getPartitionReader(): PartitionReader[InternalRow] = reader

/**
* The record types in the read buffer.
*/
@@ -22,16 +22,19 @@ import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.spark.{JobExecutionStatus, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Status._
import org.apache.spark.scheduler._
import org.apache.spark.sql.connector.CustomMetric
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashMap

class SQLAppStatusListener(
@@ -199,7 +202,37 @@ class SQLAppStatusListener(
}

private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
val accumIds = exec.metrics.map(_.accumulatorId).toSet

val metricAggregationMap = new mutable.HashMap[String, (Array[Long], Array[Long]) => String]()
val metricAggregationMethods = exec.metrics.map { m =>
val optClassName = CustomMetrics.parseV2CustomMetricType(m.metricType)
val metricAggMethod = optClassName.map { className =>
if (metricAggregationMap.contains(className)) {
metricAggregationMap(className)
} else {
// Try to initiate custom metric object
try {
val metric = Utils.loadExtensions(classOf[CustomMetric], Seq(className), conf).head
Member: Hi @viirya, loadExtensions requires that the class have a no-arg constructor, or a 1-arg constructor that takes a SparkConf. Do we really need this restriction for CustomMetric?

Member Author: For current usage, I don't see other necessary args for the CustomMetric constructor. Besides, as an interface, we cannot define a constructor for CustomMetric. I think CustomMetric should be a pretty simple class and I'd like to keep it as simple as possible. If there are more usages that require additional args, we can add an initialization API to pass some parameters. But I doubt that is necessary.

Member: I agree that we should keep it as simple as possible. I'm asking because a simple case class does not work, since it doesn't have a no-arg constructor or a 1-arg constructor:

case class SumMetric(name: String, description: String) extends CustomSumMetric

It took me a while to dig this out because it just doesn't work, with no WARN/ERROR logs.

Member Author: Good point. Added a warning log at #37386
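For illustration only (not the exact Utils.loadExtensions code), re-instantiation on the driver needs roughly this to succeed, which is why a concrete class with a no-arg constructor works and the case class above does not:

// Simplified sketch; loadExtensions also accepts a 1-arg constructor taking a SparkConf.
val className = "org.apache.spark.sql.execution.metric.CustomSumMetric" // parsed from the metric type
val metric = Class.forName(className)
  .getDeclaredConstructor()
  .newInstance()
  .asInstanceOf[CustomMetric]
// A case class like SumMetric(name, description) only has a 2-arg constructor,
// so this lookup fails and aggregation falls back to the "N/A" branch below.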

val method =
(metrics: Array[Long], _: Array[Long]) => metric.aggregateTaskMetrics(metrics)
metricAggregationMap.put(className, method)
method
} catch {
case NonFatal(_) =>
// Cannot initiaize custom metric object, we might be in history server that does
Member: Nit: initialize

// not have the custom metric class.
val defaultMethod = (_: Array[Long], _: Array[Long]) => "N/A"
metricAggregationMap.put(className, defaultMethod)
defaultMethod
}
}
}.getOrElse(
// Built-in SQLMetric
SQLMetrics.stringValue(m.metricType, _, _)
)
(m.accumulatorId, metricAggMethod)
}.toMap

val liveStageMetrics = exec.stages.toSeq
.flatMap { stageId => Option(stageMetrics.get(stageId)) }
Expand All @@ -212,7 +245,7 @@ class SQLAppStatusListener(

val maxMetricsFromAllStages = new mutable.HashMap[Long, Array[Long]]()

taskMetrics.filter(m => metricTypes.contains(m._1)).foreach { case (id, values) =>
taskMetrics.filter(m => accumIds.contains(m._1)).foreach { case (id, values) =>
val prev = allMetrics.getOrElse(id, null)
val updated = if (prev != null) {
prev ++ values
Expand All @@ -223,7 +256,7 @@ class SQLAppStatusListener(
}

// Find the max for each metric id between all stages.
val validMaxMetrics = maxMetrics.filter(m => metricTypes.contains(m._1))
val validMaxMetrics = maxMetrics.filter(m => accumIds.contains(m._1))
validMaxMetrics.foreach { case (id, value, taskId, stageId, attemptId) =>
val updated = maxMetricsFromAllStages.getOrElse(id, Array(value, stageId, attemptId, taskId))
if (value > updated(0)) {
Expand All @@ -236,7 +269,7 @@ class SQLAppStatusListener(
}

exec.driverAccumUpdates.foreach { case (id, value) =>
if (metricTypes.contains(id)) {
if (accumIds.contains(id)) {
val prev = allMetrics.getOrElse(id, null)
val updated = if (prev != null) {
// If the driver updates same metrics as tasks and has higher value then remove
Expand All @@ -256,7 +289,7 @@ class SQLAppStatusListener(
}

val aggregatedMetrics = allMetrics.map { case (id, values) =>
id -> SQLMetrics.stringValue(metricTypes(id), values, maxMetricsFromAllStages.getOrElse(id,
id -> metricAggregationMethods(id)(values, maxMetricsFromAllStages.getOrElse(id,
Array.empty[Long]))
}.toMap

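To close the loop, a small illustrative example of what the new aggregation path produces for the built-in metric classes added in this PR (values are made up):

val taskValues = Array(10L, 20L, 13L)

val sum = new CustomSumMetric
sum.aggregateTaskMetrics(taskValues)  // "43"

val avg = new CustomAvgMetric
avg.aggregateTaskMetrics(taskValues)  // "14" (integer division truncates 14.33; see the Long-vs-Double thread)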