diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala
index 78d05efb0c2..17ea163b1af 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala
@@ -22,7 +22,8 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import ai.rapids.cudf.{NvtxColor, NvtxRange}
+import ai.rapids.cudf.{NvtxColor, NvtxRange, NvtxUniqueRange}
+import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
 
 import org.apache.spark.TaskContext
@@ -179,6 +180,7 @@ private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging {
    */
   private val activeThreads = new util.LinkedHashSet[Thread]()
   private lazy val numPermits = GpuSemaphore.computeNumPermits(SQLConf.get)
+  private lazy val trackSemaphore = RapidsConf.TRACE_TASK_GPU_OWNERSHIP.get(SQLConf.get)
   /**
    * If this task holds the GPU semaphore or not.
    */
@@ -187,6 +189,8 @@ private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging {
 
   type GpuBackingSemaphore = PrioritySemaphore[Long]
 
+  var nvtxRange: Option[NvtxUniqueRange] = None
+
   /**
    * Does this task have the GPU semaphore or not. Be careful because it can change at
    * any point in time. So only use it for logging.
@@ -258,6 +262,10 @@ private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging {
           // We now own the semaphore so we need to wake up all of the other tasks that are
           // waiting.
           hasSemaphore = true
+          if (trackSemaphore) {
+            nvtxRange =
+              Some(new NvtxUniqueRange(s"Sem-${taskAttemptId}", NvtxColor.ORANGE))
+          }
           moveToActive(t)
           notifyAll()
           done = true
@@ -309,6 +317,10 @@ private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging {
       semaphore.release(numPermits)
       hasSemaphore = false
       lastHeld = System.currentTimeMillis()
+      nvtxRange match {
+        case Some(range) => range.safeClose()
+        case _ => // do nothing
+      }
     }
     // It should be impossible for the current thread to be blocked when releasing the semaphore
     // because no blocked thread should ever leave `blockUntilReady`, which is where we put it in
@@ -325,6 +337,7 @@ private final class GpuSemaphore() extends Logging {
   type GpuBackingSemaphore = PrioritySemaphore[Long]
   private val semaphore = new GpuBackingSemaphore(MAX_PERMITS)
   // Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU
+  // taskAttemptId => semaphoreTaskInfo
   private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo]
 
   def tryAcquire(context: TaskContext): TryAcquireResult = {
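The patch reaches for `NvtxUniqueRange` rather than the existing `NvtxRange` because the semaphore can be acquired on one thread and released later from another: `NvtxRange` is a push/pop range scoped to a single thread's stack, while a start/end range keyed to the task survives that handoff. Below is a minimal sketch of the lifecycle, assuming those cudf semantics; the `TaskTrace` wrapper and its method names are hypothetical and are not part of the plugin.

```scala
// Hypothetical sketch: one NVTX range per task, opened on acquire and closed on
// release, possibly from a different thread than the one that opened it.
import ai.rapids.cudf.{NvtxColor, NvtxUniqueRange}

class TaskTrace(taskAttemptId: Long) {
  // Option mirrors the plugin's nvtxRange field: None until the task owns the GPU.
  private var range: Option[NvtxUniqueRange] = None

  def onAcquire(): Unit = synchronized {
    // Same naming scheme as the diff: the range shows up as "Sem-<taskAttemptId>".
    range = Some(new NvtxUniqueRange(s"Sem-$taskAttemptId", NvtxColor.ORANGE))
  }

  def onRelease(): Unit = synchronized {
    // The plugin calls safeClose() (from RapidsPluginImplicits) to tolerate close
    // failures; plain close() is used here to keep the sketch self-contained.
    range.foreach(_.close())
    range = None
  }
}
```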
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 50dc457268c..2147773a4de 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -2383,6 +2383,13 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
     .booleanConf
     .createWithDefault(true)
 
+  val TRACE_TASK_GPU_OWNERSHIP = conf("spark.rapids.sql.traceTaskGpuOwnership")
+    .doc("Enable tracing of the GPU ownership of tasks. This can be useful for debugging " +
+      "deadlocks and other issues related to the GPU semaphore.")
+    .internal()
+    .booleanConf
+    .createWithDefault(false)
+
   private def printSectionHeader(category: String): Unit =
     println(s"\n### $category")
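To see the new ranges in practice, enable the internal flag and profile the executor with Nsight Systems. The following is a usage sketch, not part of the diff: the config key comes from the patch above, while the session wiring and the profiler invocation are assumptions based on the standard spark-rapids setup.

```scala
// Assumed setup: spark.plugins points at the usual spark-rapids plugin class;
// the trace key is the one introduced in RapidsConf.scala above.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("gpu-semaphore-trace")
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.rapids.sql.traceTaskGpuOwnership", "true")
  .getOrCreate()

// When the executor JVM runs under `nsys profile`, each task that holds the GPU
// semaphore appears on the timeline as an NVTX range named "Sem-<taskAttemptId>",
// colored orange per the diff.
```

Because the conf is marked `.internal()` and defaults to `false`, the tracing is opt-in and adds no overhead unless explicitly enabled.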