[TASK] Optimize the storage of accumulables in core tools #1263

Merged: 16 commits, Aug 8, 2024
@@ -24,7 +24,7 @@ import com.nvidia.spark.rapids.tool.qualification.QualSQLPlanAnalyzer

import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphNode}
import org.apache.spark.sql.rapids.tool.{AppBase, RDDCheckHelper, SQLMetricsStats, SqlPlanInfoGraphBuffer, SqlPlanInfoGraphEntry}
import org.apache.spark.sql.rapids.tool.{AppBase, RDDCheckHelper, SqlPlanInfoGraphBuffer, SqlPlanInfoGraphEntry}
import org.apache.spark.sql.rapids.tool.profiling.ApplicationInfo
import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
@@ -59,6 +59,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
/**
* Connects Operators to Stages using AccumulatorIDs.
* TODO: This function can be fused in the visitNode function to avoid the extra iteration.
*
* @param cb function that creates a SparkPlanGraph. This can be used as a cache holder for the
* created object so that it can be reused later.
*/
@@ -80,7 +81,8 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
* Both Qual/Prof analysis.
* For Qual apps, the app.sqlIDtoProblematic won't be set because it is done later during the
* aggregation phase.
*
* @param sqlId the SQL ID being analyzed
* @param potentialProblems a set of strings that represent the potential problems found in the
* SQL plan.
*/
@@ -112,26 +114,26 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
* 1- allSQLMetrics: a list of SQLMetricInfoCase
* 2- wholeStage: a list of WholeStageCodeGenResults
* 3- unsupportedSQLPlan: a list of UnsupportedSQLPlan that contains the SQL ID, node ID,
* node name.
* TODO: Consider handling the construction of this list in a different way for the
* Qualification
* 4- sqlPlanNodeIdToStageIds: A map between (SQL ID, Node ID) and the set of stage IDs
*
* It has the following effect on the visitor object:
* 1- It updates the sqlIsDsOrRDD argument to True when the visited node is an RDD or Dataset.
* 2- If the SQLID is an RDD, the potentialProblems is cleared because once SQL is marked as RDD,
* all the other problems are ignored. Note that we need to set that flag only once to True
* for the given SQLID.
* 3- It appends the current node's potential problems to the SQLID problems only if
* visitor.sqlIsDsOrRDD is False. Otherwise, it is kind of redundant to keep checking for
* potential problems for every node when they get to be ignored.
*
* It has the following effect on the app object:
* 1- it updates dataSourceInfo with V2 and V1 data sources
* 2- it updates sqlIDtoProblematic the map between SQL ID and potential problems
*
* @param visitor the visitor context defined per SQLPlan
* @param node the node being currently visited.
*/
protected def visitNode(visitor: SQLPlanVisitorContext,
node: SparkPlanGraphNode): Unit = {
@@ -154,7 +156,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
if (nodeIsDsOrRDD) {
// we want to report every node that is an RDD
val thisPlan = UnsupportedSQLPlan(visitor.sqlPIGEntry.sqlID, node.id, node.name, node.desc,
"Contains Dataset or RDD")
"Contains Dataset or RDD")
unsupportedSQLPlan += thisPlan
// If one node is RDD, the Sql should be set too
if (!visitor.sqlIsDsOrRDD) { // We need to set the flag only once for the given sqlID
@@ -265,46 +267,9 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
sqlToStages.toSeq
}

// Store (min, median, max, total) for a given metric
private case class StatisticsMetrics(min: Long, med:Long, max:Long, total: Long)

def generateSQLAccums(): Seq[SQLAccumProfileResults] = {
allSQLMetrics.flatMap { metric =>
val jobsForSql = app.jobIdToInfo.filter { case (_, jc) =>
// Avoid getOrElse to reduce memory allocations
jc.sqlID.isDefined && jc.sqlID.get == metric.sqlID
}
val stageIdsForSQL = jobsForSql.flatMap(_._2.stageIds).toSet
val accumsOpt = app.taskStageAccumMap.get(metric.accumulatorId)
val taskMax = accumsOpt match {
case Some(accums) =>
val filtered = accums.filter { a =>
stageIdsForSQL.contains(a.stageId)
}
// If metricType is size, average or timing, we want to read field `update` value
// to get the min, median, max, and total. Otherwise, we want to use field `value`.
if (SQLMetricsStats.hasStats(metric.metricType)) {
val accumValues = filtered.map(_.update.getOrElse(0L)).sortWith(_ < _)
if (accumValues.isEmpty) {
None
}
else if (accumValues.length <= 1) {
Some(StatisticsMetrics(0L, 0L, 0L, accumValues.sum))
} else {
Some(StatisticsMetrics(accumValues(0), accumValues(accumValues.size / 2),
accumValues(accumValues.size - 1), accumValues.sum))
}
} else {
val accumValues = filtered.map(_.value.getOrElse(0L))
if (accumValues.isEmpty) {
None
} else {
Some(StatisticsMetrics(0L, 0L, 0L, accumValues.max))
}
}
case None => None
}

val accumTaskStats = app.accumManager.calculateAccStats(metric.accumulatorId)
// local mode driver gets updates
val driverAccumsOpt = app.driverAccumMap.get(metric.accumulatorId)
val driverMax = driverAccumsOpt match {
@@ -325,9 +290,9 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
None
}

if (taskMax.isDefined || driverMax.isDefined) {
val taskInfo = taskMax.getOrElse(StatisticsMetrics(0L, 0L, 0L, 0L))
val driverInfo = driverMax.getOrElse(StatisticsMetrics(0L, 0L, 0L, 0L))
if (accumTaskStats.isDefined || driverMax.isDefined) {
val taskInfo = accumTaskStats.getOrElse(StatisticsMetrics.ZERO_RECORD)
val driverInfo = driverMax.getOrElse(StatisticsMetrics.ZERO_RECORD)

val max = Math.max(taskInfo.max, driverInfo.max)
val min = Math.max(taskInfo.min, driverInfo.min)
@@ -354,47 +319,37 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
* @return a sequence of AccumProfileResults
*/
def generateStageLevelAccums(): Seq[AccumProfileResults] = {

def computeStatistics(updates: Seq[Long]): Option[StatisticsMetrics] = {
// drop the metrics if there are no values
if (updates.isEmpty) {
None
} else if (updates.length == 1) {
Some(StatisticsMetrics(0L, 0L, 0L, updates.sum))
} else {
Some(StatisticsMetrics(
min = updates.head,
med = updates(updates.size / 2),
max = updates.last,
total = updates.sum
))
}
app.accumManager.accumInfoMap.flatMap { accumMapEntry =>
val accumInfo = accumMapEntry._2
accumInfo.stageValuesMap.keySet.flatMap( stageId => {
val stageTaskIds = app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId).toSet
// get the task updates that belong to that stage
val taskUpdatesSubset =
accumInfo.taskUpdatesMap.filterKeys(stageTaskIds.contains).values.toSeq.sorted
if (taskUpdatesSubset.isEmpty) {
None
} else {
val min = taskUpdatesSubset.head
val max = taskUpdatesSubset.last
val sum = taskUpdatesSubset.sum
val median = if (taskUpdatesSubset.size % 2 == 0) {
val mid = taskUpdatesSubset.size / 2
(taskUpdatesSubset(mid) + taskUpdatesSubset(mid - 1)) / 2
} else {
taskUpdatesSubset(taskUpdatesSubset.size / 2)
}
Some(AccumProfileResults(
appIndex,
stageId,
accumInfo.infoRef,
min = min,
median = median,
max = max,
total = sum))
}
})
}

// Process taskStageAccumMap to get all the accumulators
val stageLevelAccums = app.taskStageAccumMap.values.flatten
val groupedByStageAndAccumulatorId = stageLevelAccums.groupBy(
task => (task.stageId, task.accumulatorId))
// Extract and sort the update values, defaulting to 0 if not present
groupedByStageAndAccumulatorId.flatMap { case ((stageId, accumulatorId), accums) =>
val sortedUpdates = accums.flatMap(_.update).toSeq.sorted

// Compute the statistics for the accumulator if applicable
computeStatistics(sortedUpdates).map { stats =>
val sampleAccum = accums.head
AccumProfileResults(
appIndex = appIndex,
stageId = stageId,
accumulatorId = accumulatorId,
name = sampleAccum.name.getOrElse("Unknown"),
min = stats.min,
median = stats.med,
max = stats.max,
total = stats.total
)
}
}.toSeq
}
}.toSeq
}

object AppSQLPlanAnalyzer {
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.tool.analysis

// Store (min, median, max, total) for a given metric
case class StatisticsMetrics(min: Long, med: Long, max: Long, total: Long)

object StatisticsMetrics {
// a static variable used to represent zero-statistics instead of allocating a dummy record
// on every calculation.
val ZERO_RECORD: StatisticsMetrics = StatisticsMetrics(0L, 0L, 0L, 0L)
}
@@ -220,7 +220,7 @@ case class ExecInfo(
object ExecInfo {
// Used to create an execInfo without recalculating the dataSet or Udf.
// This is helpful when we know that the node description may contain some patterns that can be
// mistakenly identified as UDFs
def createExecNoNode(sqlID: Long,
exec: String,
expr: String,
@@ -443,9 +443,7 @@ object SQLPlanParser extends Logging {

def getStagesInSQLNode(node: SparkPlanGraphNode, app: AppBase): Set[Int] = {
val nodeAccums = node.metrics.map(_.accumulatorId)
nodeAccums.flatMap { nodeAccumId =>
app.stageManager.getStagesIdsByAccumId(nodeAccumId)
}.toSet
nodeAccums.flatMap(app.accumManager.getAccStageIds).toSet
}

// Set containing execs that refer to other expressions. We need this to be a list to allow
@@ -613,15 +611,10 @@ object SQLPlanParser extends Logging {
* the duration.
*/
def getTotalDuration(accumId: Option[Long], app: AppBase): Option[Long] = {
val taskForAccum = accumId.flatMap(id => app.taskStageAccumMap.get(id))
.getOrElse(ArrayBuffer.empty)
val accumValues = taskForAccum.map(_.value.getOrElse(0L))
val maxDuration = if (accumValues.isEmpty) {
None
} else {
Some(accumValues.max)
accumId match {
case Some(x) => app.accumManager.getMaxStageValue(x)
case _ => None
}
maxDuration
}

def getDriverTotalDuration(accumId: Option[Long], app: AppBase): Option[Long] = {
@@ -33,7 +33,7 @@ class CompareApplications(apps: Seq[ApplicationInfo]) extends Logging {
val normalizedByAppId = apps.map { app =>
val normalized = app.sqlPlans.mapValues { plan =>
SparkPlanInfoWithStage(plan,
app.stageManager.getAccumToSingleStage()).normalizeForStageComparison
app.accumManager.getAccumSingleStage).normalizeForStageComparison
}
(app.appId, normalized)
}.toMap
@@ -88,7 +88,7 @@ object GenerateDot {
val accumSummary = accums.map { a =>
Seq(a.sqlID, a.accumulatorId, a.total)
}
val accumIdToStageId = app.stageManager.getAccumToSingleStage()
val accumIdToStageId = app.accumManager.getAccumSingleStage
val formatter = java.text.NumberFormat.getIntegerInstance
val stageIdToStageMetrics = app.taskManager.stageAttemptToTasks.collect { case (stageId, _) =>
val tasks = app.taskManager.getAllTasksStageAttempt(stageId)
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import com.nvidia.spark.rapids.tool.ToolTextFileWriter

import org.apache.spark.sql.rapids.tool.profiling.ApplicationInfo
import org.apache.spark.sql.rapids.tool.store.{AccNameRef, AccumInfo}

abstract class TimelineTiming(
val startTime: Long,
@@ -273,27 +274,22 @@ object GenerateTimeline {
}
}

val semMetricsNs = semWaitIds.toList.flatMap { id =>
app.taskStageAccumMap.get(id)
}.flatten
val semMetricsNs = semWaitIds.toList
.flatMap(app.accumManager.accumInfoMap.get)
.flatMap(_.taskUpdatesMap.values).sum

val semMetricsMs = app.taskStageAccumMap.values.filter { buffer =>
buffer.headOption.exists { h =>
h.name.exists(_ == "gpuSemaphoreWait")
}
}.flatten
val semMetricsMs = app.accumManager.accumInfoMap.flatMap {
case (_, accumInfo: AccumInfo)
if accumInfo.infoRef.name == AccNameRef.NAMES_TABLE.get("gpuSemaphoreWait") =>
Some(accumInfo.taskUpdatesMap.values.sum)
case _ => None
}.sum

val readMetrics = readTimeIds.toList.flatMap { id =>
app.taskStageAccumMap.get(id)
}.flatten
val readMetrics = readTimeIds.toList.flatMap(app.accumManager.accumInfoMap.get)

val opMetrics = opTimeIds.toList.flatMap { id =>
app.taskStageAccumMap.get(id)
}.flatten
val opMetrics = opTimeIds.toList.flatMap(app.accumManager.accumInfoMap.get)

val writeMetrics = writeTimeIds.toList.flatMap { id =>
app.taskStageAccumMap.get(id)
}.flatten
val writeMetrics = writeTimeIds.toList.flatMap(app.accumManager.accumInfoMap.get)

app.taskManager.getAllTasks().foreach { tc =>
val host = tc.host
@@ -303,20 +299,12 @@ object GenerateTimeline {
val launchTime = tc.launchTime
val finishTime = tc.finishTime
val duration = tc.duration
val semTimeMs = (semMetricsNs.filter { m =>
m.stageId == stageId && m.taskId.contains(taskId) && m.update.isDefined
}.flatMap(_.update).sum / 1000000) + (semMetricsMs.filter{ m =>
m.stageId == stageId && m.taskId.contains(taskId) && m.update.isDefined
}.flatMap(_.update).sum)
val readTimeMs = readMetrics.filter { m =>
m.stageId == stageId && m.taskId.contains(taskId) && m.update.isDefined
}.flatMap(_.update).sum / 1000000 + tc.sr_fetchWaitTime
val opTimeMs = opMetrics.filter { m =>
m.stageId == stageId && m.taskId.contains(taskId) && m.update.isDefined
}.flatMap(_.update).sum / 1000000
val writeTimeMs = writeMetrics.filter { m =>
m.stageId == stageId && m.taskId.contains(taskId) && m.update.isDefined
}.flatMap(_.update).sum / 1000000 + tc.sw_writeTime
val semTimeMs = (semMetricsNs / 1000000) + semMetricsMs
val readTimeMs = readMetrics.flatMap(_.taskUpdatesMap.get(taskId)).sum / 1000000 +
tc.sr_fetchWaitTime
val opTimeMs = opMetrics.flatMap(_.taskUpdatesMap.get(taskId)).sum / 1000000
val writeTimeMs = writeMetrics.flatMap(_.taskUpdatesMap.get(taskId)).sum / 1000000 +
tc.sw_writeTime
val taskInfo = new TimelineTaskInfo(stageId, taskId, launchTime, finishTime, duration,
tc.executorDeserializeTime, readTimeMs, semTimeMs, opTimeMs, writeTimeMs)
val execHost = s"$execId/$host"