Commit 8f7eace: Use JDK's Cleaner instead

HyukjinKwon committed May 24, 2024
1 parent 2516fd8 commit 8f7eace

Showing 3 changed files with 16 additions and 120 deletions.
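The diff touches three files: the call site in Dataset, the lifecycle code in SparkSession, and SessionCleaner itself, which shrinks from a hand-rolled design (a ReferenceQueue, a WeakReference subclass carrying a cleanup task, and a manually managed daemon thread) to the JDK's java.lang.ref.Cleaner, available since Java 9. A Cleaner owns its own daemon thread and runs each registered action at most once, after the registered object becomes phantom reachable. A minimal, self-contained sketch of that pattern (the names here are illustrative, not from Spark):

    import java.lang.ref.Cleaner

    object CleanerSketch {
      private val cleaner = Cleaner.create()

      // Register `obj` for cleanup. The action must not capture `obj` itself,
      // or the object stays strongly reachable and is never cleaned; here the
      // lambda captures only the id string.
      def track(obj: AnyRef, id: String): Unit =
        cleaner.register(obj, () => println(s"released $id"))

      def main(args: Array[String]): Unit = {
        track(new Object, "df-1")
        System.gc() // a hint only; the action runs after the object is collected
        Thread.sleep(500)
      }
    }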
@@ -3492,7 +3492,7 @@ class Dataset[T] private[sql] (
       .getOrElse(throw new RuntimeException("CheckpointCommandResult must be present"))

     val cachedRemoteRelation = response.getCheckpointCommandResult.getRelation
-    sparkSession.cleaner.registerCachedRemoteRelationForCleanup(cachedRemoteRelation)
+    sparkSession.cleaner.register(cachedRemoteRelation)

     // Update the builder with the values from the result.
     builder.setCachedRemoteRelation(cachedRemoteRelation)
@@ -73,11 +73,7 @@ class SparkSession private[sql] (
     with Logging {

   private[this] val allocator = new RootAllocator()
-  private var shouldStopCleaner = false
-  private[sql] lazy val cleaner = {
-    shouldStopCleaner = true
-    new SessionCleaner(this)
-  }
+  private[sql] lazy val cleaner = new SessionCleaner(this)

   // a unique session ID for this session from client.
   private[sql] def sessionId: String = client.sessionId
@@ -719,9 +715,6 @@ class SparkSession private[sql] (
     if (releaseSessionOnClose) {
       client.releaseSession()
     }
-    if (shouldStopCleaner) {
-      cleaner.stop()
-    }
     client.shutdown()
     allocator.close()
     SparkSession.onSessionClose(this)
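Why this block could be deleted outright: java.lang.ref.Cleaner creates and manages its own daemon thread (per its Javadoc), so the session no longer needs to remember whether a cleaner was started, nor stop it on close; the thread simply dies with the JVM.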
@@ -17,130 +17,33 @@

 package org.apache.spark.sql.internal

-import java.lang.ref.{ReferenceQueue, WeakReference}
-import java.util.Collections
-import java.util.concurrent.ConcurrentHashMap
+import java.lang.ref.Cleaner

 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession

-/**
- * Classes that represent cleaning tasks.
- */
-private sealed trait CleanupTask
-private case class CleanupCachedRemoteRelation(dfID: String) extends CleanupTask
-
-/**
- * A WeakReference associated with a CleanupTask.
- *
- * When the referent object becomes only weakly reachable, the corresponding
- * CleanupTaskWeakReference is automatically added to the given reference queue.
- */
-private class CleanupTaskWeakReference(
-    val task: CleanupTask,
-    referent: AnyRef,
-    referenceQueue: ReferenceQueue[AnyRef])
-  extends WeakReference(referent, referenceQueue)
-
-/**
- * An asynchronous cleaner for objects.
- *
- * This maintains a weak reference for each CashRemoteRelation, etc. of interest, to be processed
- * when the associated object goes out of scope of the application. Actual cleanup is performed in
- * a separate daemon thread.
- */
 private[sql] class SessionCleaner(session: SparkSession) extends Logging {
-
-  /**
-   * How often (seconds) to trigger a garbage collection in this JVM. This context cleaner
-   * triggers cleanups only when weak references are garbage collected. In long-running
-   * applications with large driver JVMs, where there is little memory pressure on the driver,
-   * this may happen very occasionally or not at all. Not cleaning at all may lead to executors
-   * running out of disk space after a while.
-   */
-  private val refQueuePollTimeout: Long = 100
-
-  /**
-   * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
-   * have not been handled by the reference queue.
-   */
-  private val referenceBuffer =
-    Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap)
-
-  private val referenceQueue = new ReferenceQueue[AnyRef]
-
-  private val cleaningThread = new Thread() { override def run(): Unit = keepCleaning() }
-
-  @volatile private var started = false
-  @volatile private var stopped = false
-
-  /** Start the cleaner. */
-  def start(): Unit = {
-    cleaningThread.setDaemon(true)
-    cleaningThread.setName("Spark Connect Context Cleaner")
-    cleaningThread.start()
-  }
-
-  /**
-   * Stop the cleaning thread and wait until the thread has finished running its current task.
-   */
-  def stop(): Unit = {
-    stopped = true
-    // Interrupt the cleaning thread, but wait until the current task has finished before
-    // doing so. This guards against the race condition where a cleaning thread may
-    // potentially clean similarly named variables created by a different SparkSession.
-    synchronized {
-      cleaningThread.interrupt()
-    }
-    cleaningThread.join()
-  }
+  private val cleaner = Cleaner.create()

   /** Register a CachedRemoteRelation for cleanup when it is garbage collected. */
-  def registerCachedRemoteRelationForCleanup(relation: proto.CachedRemoteRelation): Unit = {
-    registerForCleanup(relation, CleanupCachedRemoteRelation(relation.getRelationId))
-  }
-
-  /** Register an object for cleanup. */
-  private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
-    if (!started) {
-      // Lazily starts when the first cleanup is registered.
-      start()
-      started = true
-    }
-    referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue))
+  def register(relation: proto.CachedRemoteRelation): Unit = {
+    val dfID = relation.getRelationId
+    cleaner.register(relation, () => doCleanupCachedRemoteRelation(dfID))
   }

-  /** Keep cleaning objects. */
-  private def keepCleaning(): Unit = {
-    while (!stopped && !session.client.channel.isShutdown) {
-      try {
-        val reference = Option(referenceQueue.remove(refQueuePollTimeout))
-          .map(_.asInstanceOf[CleanupTaskWeakReference])
-        // Synchronize here to avoid being interrupted on stop()
-        synchronized {
-          reference.foreach { ref =>
-            logDebug("Got cleaning task " + ref.task)
-            referenceBuffer.remove(ref)
-            ref.task match {
-              case CleanupCachedRemoteRelation(dfID) =>
-                doCleanupCachedRemoteRelation(dfID)
-            }
-          }
-        }
-      } catch {
-        case e: Throwable => logError("Error in cleaning thread", e)
-      }
-    }
-  }
-
-  /** Perform CleanupCachedRemoteRelation cleanup. */
-  private[spark] def doCleanupCachedRemoteRelation(dfID: String): Unit = {
-    session.execute {
-      session.newCommand { builder =>
-        builder.getRemoveCachedRemoteRelationCommandBuilder
-          .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfID).build())
-      }
-    }
+  private[sql] def doCleanupCachedRemoteRelation(dfID: String): Unit = {
+    try {
+      if (!session.client.channel.isShutdown) {
+        session.execute {
+          session.newCommand { builder =>
+            builder.getRemoveCachedRemoteRelationCommandBuilder
+              .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfID).build())
+          }
+        }
+      }
+    } catch {
+      case e: Throwable => logError("Error in cleaning thread", e)
+    }
   }
 }
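A design note on the new register(): Cleaner.register returns a Cleanable, which the code above discards, so cleanup happens only via garbage collection. A sketch (illustrative, not part of the commit) of how keeping the Cleanable would also allow eager, idempotent release, the usual pairing with AutoCloseable:

    import java.lang.ref.Cleaner

    object TrackedRelation {
      private val cleaner = Cleaner.create()
      // Built outside the instance so the action cannot capture `this`.
      private def release(id: String): Runnable =
        () => println(s"RemoveCachedRemoteRelation($id)")
    }

    final class TrackedRelation(relationId: String) extends AutoCloseable {
      private val cleanable =
        TrackedRelation.cleaner.register(this, TrackedRelation.release(relationId))

      // clean() runs the action at most once, whether invoked here or by GC.
      override def close(): Unit = cleanable.clean()
    }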
