diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileMain.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileMain.scala index f2c4ba6a8..b86b49525 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileMain.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileMain.scala @@ -30,7 +30,7 @@ object ProfileMain extends Logging { * Entry point from spark-submit running this as the driver. */ def main(args: Array[String]) { - val (exitCode, _) = mainInternal(new ProfileArgs(args)) + val (exitCode, _) = mainInternal(new ProfileArgs(args), enablePB = true) if (exitCode != 0) { System.exit(exitCode) } @@ -39,7 +39,7 @@ object ProfileMain extends Logging { /** * Entry point for tests */ - def mainInternal(appArgs: ProfileArgs): (Int, Int) = { + def mainInternal(appArgs: ProfileArgs, enablePB: Boolean = false): (Int, Int) = { // Parsing args val eventlogPaths = appArgs.eventlog() @@ -67,7 +67,7 @@ object ProfileMain extends Logging { return (0, filteredLogs.size) } - val profiler = new Profiler(hadoopConf, appArgs) + val profiler = new Profiler(hadoopConf, appArgs, enablePB) profiler.profile(eventLogFsFiltered) (0, filteredLogs.size) } diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala index ca285eaf9..b61b58a97 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala @@ -28,8 +28,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.tool.profiling.ApplicationInfo +import org.apache.spark.sql.rapids.tool.ui.ConsoleProgressBar -class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging { +class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolean) extends Logging { private val nThreads = appArgs.numThreads.getOrElse( Math.ceil(Runtime.getRuntime.availableProcessors() / 4f).toInt) @@ -48,6 +49,7 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging private val outputCombined: Boolean = appArgs.combined() private val useAutoTuner: Boolean = appArgs.autoTuner() + private var progressBar: Option[ConsoleProgressBar] = None logInfo(s"Threadpool size is $nThreads") @@ -58,6 +60,9 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging * what else we can do in parallel. */ def profile(eventLogInfos: Seq[EventLogInfo]): Unit = { + if (enablePB && eventLogInfos.nonEmpty) { // total count to start the PB cannot be 0 + progressBar = Some(new ConsoleProgressBar("Profile Tool", eventLogInfos.length)) + } if (appArgs.compare()) { if (outputCombined) { logError("Output combined option not valid with compare mode!") @@ -69,6 +74,7 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging val apps = createApps(eventLogInfos) if (apps.size < 2) { + progressBar.foreach(_.reportUnknownStatusProcesses(apps.size)) logError("At least 2 applications are required for comparison mode. Exiting!") } else { val profileOutputWriter = new ProfileOutputWriter(s"$outputDir/compare", @@ -76,9 +82,12 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging try { // we need the info for all of the apps to be able to compare so this happens serially val (sums, comparedRes) = processApps(apps, printPlans = false, profileOutputWriter) - writeOutput(profileOutputWriter, Seq(sums), false, comparedRes) - } - finally { + progressBar.foreach(_.reportSuccessfulProcesses(apps.size)) + writeSafelyToOutput(profileOutputWriter, Seq(sums), false, comparedRes) + } catch { + case _: Exception => + progressBar.foreach(_.reportFailedProcesses(apps.size)) + } finally { profileOutputWriter.close() } } @@ -94,11 +103,8 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging val profileOutputWriter = new ProfileOutputWriter(s"$outputDir/combined", Profiler.COMBINED_LOG_FILE_NAME_PREFIX, numOutputRows, outputCSV = outputCSV) val sums = createAppsAndSummarize(eventLogInfos, false, profileOutputWriter) - try { - writeOutput(profileOutputWriter, sums, outputCombined) - } finally { - profileOutputWriter.close() - } + writeSafelyToOutput(profileOutputWriter, sums, outputCombined) + profileOutputWriter.close() } } else { // Read each application and process it separately to save memory. @@ -115,6 +121,7 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging threadPool.shutdownNow() } } + progressBar.foreach(_.finishAll()) } private def errorHandler(error: Throwable, path: EventLogInfo) = { @@ -139,9 +146,16 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging def run: Unit = { try { val appOpt = createApp(path, index, hadoopConf) - appOpt.foreach(app => allApps.add(app)) + appOpt match { + case Some(app) => + allApps.add(app) + case None => + progressBar.foreach(_.reportUnkownStatusProcess()) + } } catch { - case t: Throwable => errorHandler(t, path) + case t: Throwable => + progressBar.foreach(_.reportFailedProcess()) + errorHandler(t, path) } } } @@ -176,19 +190,24 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging def run: Unit = { try { val appOpt = createApp(path, index, hadoopConf) - appOpt.foreach { app => - val sum = try { - val (s, _) = processApps(Seq(app), false, profileOutputWriter) - Some(s) - } catch { - case e: Exception => - logWarning(s"Unexpected exception thrown ${path.eventLog.toString}, skipping! ", e) - None - } - sum.foreach(allApps.add(_)) + appOpt match { + case Some(app) => + try { + val (s, _) = processApps(Seq(app), false, profileOutputWriter) + progressBar.foreach(_.reportSuccessfulProcess()) + allApps.add(s) + } catch { + case t: Throwable => + progressBar.foreach(_.reportFailedProcess()) + errorHandler(t, path) + } + case None => + progressBar.foreach(_.reportUnkownStatusProcess()) } } catch { - case t: Throwable => errorHandler(t, path) + case t: Throwable => + progressBar.foreach(_.reportFailedProcess()) + errorHandler(t, path) } } } @@ -228,15 +247,23 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging try { val (sum, _) = processApps(Seq(appOpt.get), appArgs.printPlans(), profileOutputWriter) - writeOutput(profileOutputWriter, Seq(sum), false) + progressBar.foreach(_.reportSuccessfulProcess()) + writeSafelyToOutput(profileOutputWriter, Seq(sum), false) + } catch { + case t: Throwable => + progressBar.foreach(_.reportFailedProcess()) + errorHandler(t, path) } finally { profileOutputWriter.close() } case None => + progressBar.foreach(_.reportUnkownStatusProcess()) logInfo("No application to process. Exiting") } } catch { - case t: Throwable => errorHandler(t, path) + case t: Throwable => + progressBar.foreach(_.reportFailedProcess()) + errorHandler(t, path) } } } @@ -482,6 +509,22 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs) extends Logging } } } + + /** + * Safely writes the application summary information to the specified profileOutputWriter. + * If an exception occurs during the writing process, it will be caught and logged, preventing + * it from propagating further. + */ + private def writeSafelyToOutput(profileOutputWriter: ProfileOutputWriter, + appsSum: Seq[ApplicationSummaryInfo], outputCombined: Boolean, + comparedRes: Option[CompareSummaryInfo] = None): Unit = { + try { + writeOutput(profileOutputWriter, appsSum, outputCombined, comparedRes) + } catch { + case e: Exception => + logError("Exception thrown while writing", e) + } + } } object Profiler { diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/ui/ConsoleProgressBar.scala index a55b297dc..518c023b0 100644 --- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/ui/ConsoleProgressBar.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -132,6 +132,18 @@ class ConsoleProgressBar( totalCounter.incrementAndGet() } + def reportSuccessfulProcesses(n: Int): Unit = { + (1 to n).foreach(_ => reportSuccessfulProcess()) + } + + def reportFailedProcesses(n: Int): Unit = { + (1 to n).foreach(_ => reportFailedProcess()) + } + + def reportUnknownStatusProcesses(n: Int): Unit = { + (1 to n).foreach(_ => reportUnkownStatusProcess()) + } + def metricsToString: String = { val sb = new mutable.StringBuilder() metrics.foreach { case (name, counter) => @@ -193,6 +205,8 @@ class ConsoleProgressBar( /** * Mark all processing as finished. + * TODO: All processes that have not been finished (totalCount - currentCount) + * should be marked as unknown. */ def finishAll(): Unit = synchronized { stop()