Uses weak references and a cleaner
HyukjinKwon committed May 22, 2024
1 parent 124d396 commit 9e6faa2
Showing 4 changed files with 247 additions and 41 deletions.
@@ -43,7 +43,6 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.SparkClassUtils

// scalastyle:off no.finalize
/**
* A Dataset is a strongly typed collection of domain-specific objects that can be transformed in
* parallel using functional or relational operations. Each Dataset also has an untyped view
@@ -3481,7 +3480,7 @@ class Dataset[T] private[sql] (
* false creates a local checkpoint using the caching subsystem
*/
private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
val df = sparkSession.newDataset(agnosticEncoder) { builder =>
sparkSession.newDataset(agnosticEncoder) { builder =>
val command = sparkSession.newCommand { builder =>
builder.getCheckpointCommandBuilder
.setLocal(reliableCheckpoint)
@@ -3493,15 +3492,17 @@
val response = responseIter
.find(_.hasCheckpointCommandResult)
.getOrElse(throw new RuntimeException("CheckpointCommandResult must be present"))

val cachedRemoteRelation = response.getCheckpointCommandResult.getRelation
sparkSession.cleaner.registerCachedRemoteRelationForCleanup(cachedRemoteRelation)

// Update the builder with the values from the result.
builder.setCachedRemoteRelation(response.getCheckpointCommandResult.getRelation)
builder.setCachedRemoteRelation(cachedRemoteRelation)
} finally {
// consume the rest of the iterator
responseIter.foreach(_ => ())
}
}
df.cachedRemoteRelationID = Some(df.plan.getRoot.getCachedRemoteRelation.getRelationId)
df
}

/**
@@ -3554,26 +3555,6 @@ class Dataset[T] private[sql] (
}
}

// Visible for testing
private[sql] var cachedRemoteRelationID: Option[String] = None

override def finalize(): Unit = {
if (!sparkSession.client.channel.isShutdown) {
cachedRemoteRelationID.foreach { dfId =>
try {
sparkSession.execute {
sparkSession.newCommand { builder =>
builder.getRemoveCachedRemoteRelationCommandBuilder
.setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfId).build())
}
}
} catch {
case e: Throwable => logWarning("RemoveRemoteCachedRelation failed.", e)
}
}
}
}

/**
* We cannot deserialize a connect [[Dataset]] because of a class clash on the server side. We
* null out the instance for now.
@@ -41,7 +41,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf}
import org.apache.spark.sql.internal.{CatalogImpl, ContextCleaner, SqlApiConf}
import org.apache.spark.sql.streaming.DataStreamReader
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.types.StructType
@@ -73,6 +73,11 @@ class SparkSession private[sql] (
with Logging {

private[this] val allocator = new RootAllocator()
private var shouldStopCleaner = false
private[sql] lazy val cleaner = {
shouldStopCleaner = true
new ContextCleaner(this)
}

// a unique session ID for this session from client.
private[sql] def sessionId: String = client.sessionId
@@ -714,6 +719,9 @@ class SparkSession private[sql] (
if (releaseSessionOnClose) {
client.releaseSession()
}
if (shouldStopCleaner) {
cleaner.stop()
}
client.shutdown()
allocator.close()
SparkSession.onSessionClose(this)
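The hunk above creates the cleaner lazily and flips shouldStopCleaner only when the lazy val is first forced, so stop() is called during close() only if a cleaner was ever created. Below is a minimal sketch of that lazy-initialization-plus-flag idiom, using hypothetical Cleaner and Session classes rather than the real ContextCleaner and SparkSession:

class Cleaner {
  def start(): Unit = println("cleaner started")
  def stop(): Unit = println("cleaner stopped")
}

class Session {
  private var shouldStopCleaner = false

  // Forced only on first use; the flag records that a cleaner instance now exists.
  lazy val cleaner: Cleaner = {
    shouldStopCleaner = true
    val c = new Cleaner
    c.start()
    c
  }

  def close(): Unit = {
    // Only stop the cleaner if it was ever created; referencing the lazy val
    // unconditionally here would construct (and start) it just to stop it.
    if (shouldStopCleaner) cleaner.stop()
  }
}

A session that never registers anything for cleanup therefore never starts a cleaning thread, and its close() skips the stop call entirely.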
@@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.internal

import java.lang.ref.{ReferenceQueue, WeakReference}
import java.util.Collections
import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, ScheduledThreadPoolExecutor, TimeUnit}

import com.google.common.util.concurrent.ThreadFactoryBuilder

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession

/**
* Classes that represent cleaning tasks.
*/
private sealed trait CleanupTask
private case class CleanupCachedRemoteRelation(dfID: String) extends CleanupTask

/**
* A WeakReference associated with a CleanupTask.
*
* When the referent object becomes only weakly reachable, the corresponding
* CleanupTaskWeakReference is automatically added to the given reference queue.
*/
private class CleanupTaskWeakReference(
val task: CleanupTask,
referent: AnyRef,
referenceQueue: ReferenceQueue[AnyRef])
extends WeakReference(referent, referenceQueue)

/**
* An asynchronous cleaner for objects.
*
* This maintains a weak reference for each object of interest (e.g. a CachedRemoteRelation), to be processed
* when the associated object goes out of scope of the application. Actual cleanup is performed in
* a separate daemon thread.
*/
private[spark] class ContextCleaner(session: SparkSession) extends Logging {

/**
* How often (seconds) to trigger a garbage collection in this JVM. This context cleaner
* triggers cleanups only when weak references are garbage collected. In long-running
* applications with large driver JVMs, where there is little memory pressure on the driver,
* this may happen very occasionally or not at all. Not cleaning at all may lead to executors
* running out of disk space after a while.
*/
private val periodicGCInterval: Long = 30 * 60
private val refQueuePollTimeout: Long = 100

/**
* A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
* have not been handled by the reference queue.
*/
private val referenceBuffer =
Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap)

private val referenceQueue = new ReferenceQueue[AnyRef]

private val cleaningThread = new Thread() { override def run(): Unit = keepCleaning() }

private val periodicGCService: ScheduledExecutorService =
ContextCleaner.newDaemonSingleThreadScheduledExecutor(
"spark-connect-context-cleaner-periodic-gc")

@volatile private var started = false
@volatile private var stopped = false

/** Start the cleaner. */
def start(): Unit = {
cleaningThread.setDaemon(true)
cleaningThread.setName("Spark Connect Context Cleaner")
cleaningThread.start()
periodicGCService.scheduleAtFixedRate(
() => System.gc(),
periodicGCInterval,
periodicGCInterval,
TimeUnit.SECONDS)
}

/**
* Stop the cleaning thread and wait until the thread has finished running its current task.
*/
def stop(): Unit = {
stopped = true
// Interrupt the cleaning thread, but wait until the current task has finished before
// doing so. This guards against the race condition where a cleaning thread may
// potentially clean similarly named variables created by a different SparkSession.
synchronized {
cleaningThread.interrupt()
}
cleaningThread.join()
periodicGCService.shutdown()
}

/** Register a CachedRemoteRelation for cleanup when it is garbage collected. */
def registerCachedRemoteRelationForCleanup(relation: proto.CachedRemoteRelation): Unit = {
registerForCleanup(relation, CleanupCachedRemoteRelation(relation.getRelationId))
}

/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
if (!started) {
// Lazily start the cleaner when the first object is registered for cleanup.
start()
started = true
}
referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue))
}

/** Keep cleaning objects. */
private def keepCleaning(): Unit = {
while (!stopped && !session.client.channel.isShutdown) {
try {
val reference = Option(referenceQueue.remove(refQueuePollTimeout))
.map(_.asInstanceOf[CleanupTaskWeakReference])
// Synchronize here to avoid being interrupted on stop()
synchronized {
reference.foreach { ref =>
logDebug("Got cleaning task " + ref.task)
referenceBuffer.remove(ref)
ref.task match {
case CleanupCachedRemoteRelation(dfID) =>
doCleanupCachedRemoteRelation(dfID)
}
}
}
} catch {
case e: Throwable => logError("Error in cleaning thread", e)
}
}
}

/** Perform CleanupCachedRemoteRelation cleanup. */
private[spark] def doCleanupCachedRemoteRelation(dfID: String): Unit = {
session.execute {
session.newCommand { builder =>
builder.getRemoveCachedRemoteRelationCommandBuilder
.setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfID).build())
}
}
}
}

private object ContextCleaner {

/**
* Wrapper over ScheduledThreadPoolExecutor that creates a single-threaded pool of daemon threads.
*/
private def newDaemonSingleThreadScheduledExecutor(
threadName: String): ScheduledExecutorService = {
val threadFactory =
new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
val executor = new ScheduledThreadPoolExecutor(1, threadFactory)
// By default, a cancelled task is not automatically removed from the work queue until its delay
// elapses. We have to enable it manually.
executor.setRemoveOnCancelPolicy(true)
executor
}
}
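As a reading aid, here is a minimal, self-contained sketch of the weak-reference pattern the new ContextCleaner builds on: each registered object is wrapped in a WeakReference tied to a ReferenceQueue, strong references to the wrappers are kept in a set so they survive until handled, and a daemon thread drains the queue and performs the cleanup once a referent has been garbage collected. The names Release, TaskRef and CleanerSketch are illustrative only, and the println stands in for the RemoveCachedRemoteRelation command the real cleaner sends to the server.

import java.lang.ref.{ReferenceQueue, WeakReference}
import java.util.Collections
import java.util.concurrent.ConcurrentHashMap

// Hypothetical stand-in for CleanupCachedRemoteRelation.
final case class Release(relationId: String)

// A weak reference that remembers which cleanup task to run once its referent is collected.
final class TaskRef(val task: Release, referent: AnyRef, queue: ReferenceQueue[AnyRef])
  extends WeakReference[AnyRef](referent, queue)

object CleanerSketch {
  private val queue = new ReferenceQueue[AnyRef]
  // Keep strong references to the TaskRefs themselves so they are not collected before being handled.
  private val pending = Collections.newSetFromMap[TaskRef](new ConcurrentHashMap)

  private val cleaningThread = new Thread(() => {
    while (true) {
      // remove() blocks until the GC has enqueued a reference whose referent was collected.
      val ref = queue.remove().asInstanceOf[TaskRef]
      pending.remove(ref)
      println(s"would send RemoveCachedRemoteRelation for ${ref.task.relationId}")
    }
  })
  cleaningThread.setDaemon(true)
  cleaningThread.start()

  /** Register an object so its relation is released once the object becomes unreachable. */
  def register(obj: AnyRef, relationId: String): Unit =
    pending.add(new TaskRef(Release(relationId), obj, queue))
}

Calling CleanerSketch.register(someDataset, "some-relation-id") and then dropping all strong references to someDataset lets the daemon thread print the cleanup message after the next garbage collection, mirroring how registerCachedRemoteRelationForCleanup and doCleanupCachedRemoteRelation cooperate above.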
@@ -17,6 +17,7 @@
package org.apache.spark.sql

import java.io.{ByteArrayOutputStream, PrintStream}
import java.lang.ref.WeakReference
import java.nio.file.Files
import java.time.DateTimeException
import java.util.Properties
@@ -30,6 +31,8 @@ import org.apache.commons.io.FileUtils
import org.apache.commons.io.output.TeeOutputStream
import org.scalactic.TolerantNumerics
import org.scalatest.PrivateMethodTester
import org.scalatest.concurrent.Eventually.{eventually, interval, timeout}
import org.scalatest.exceptions.TestFailedDueToTimeoutException

import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException}
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
@@ -1564,27 +1567,65 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
test("checkpoint gc") {
var df1 = spark.range(100).localCheckpoint(eager = true)
val encoder = df1.agnosticEncoder
val dfId = df1.cachedRemoteRelationID.get
val dfId = df1.plan.getRoot.getCachedRemoteRelation.getRelationId

// GC triggers removal of the cached remote relation
val ref = new WeakReference[Object](df1)
df1 = null
System.gc()
while (ref.get() != null) { Thread.sleep(1000L); System.gc() }

// Make sure the cleanup happens on the server side.
Thread.sleep(3000L)
eventually(timeout(30.seconds), interval(500.millis)) {
val ex = intercept[SparkException] {
spark
.newDataset(encoder) { builder =>
builder.setCachedRemoteRelation(
proto.CachedRemoteRelation
.newBuilder()
.setRelationId(dfId)
.build())
}
.collect()
}
assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found"))
}
}

val ex = intercept[SparkException] {
spark
.newDataset(encoder) { builder =>
builder.setCachedRemoteRelation(
proto.CachedRemoteRelation
.newBuilder()
.setRelationId(dfId)
.build())
}
.collect()
test("checkpoint gc derived DataFrame") {
var df1 = spark.range(100).localCheckpoint(eager = true)
var derived = df1.repartition(10)
val encoder = df1.agnosticEncoder
val dfId = df1.plan.getRoot.getCachedRemoteRelation.getRelationId

val ref = new WeakReference[Object](df1)
df1 = null
while (ref.get() != null) { Thread.sleep(1000L); System.gc() }

def condition(): Unit = {
val ex = intercept[SparkException] {
spark
.newDataset(encoder) { builder =>
builder.setCachedRemoteRelation(
proto.CachedRemoteRelation
.newBuilder()
.setRelationId(dfId)
.build())
}
.collect()
}
assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found"))
}
assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found"))

intercept[TestFailedDueToTimeoutException] {
eventually(timeout(5.seconds), interval(500.millis))(condition())
}

// GC triggers removal of the cached remote relation
val ref1 = new WeakReference[Object](derived)
derived = null
while (ref1.get() != null) { Thread.sleep(1000L); System.gc() }

// Check that the state was removed upon garbage collection.
eventually(timeout(30.seconds), interval(500.millis))(condition())
}
}
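The two tests above wait for a real garbage collection by polling a WeakReference. Here is a standalone sketch of that idiom, assuming nothing beyond the JDK and using hypothetical names (WaitForGcSketch, payload). The ordering matters: the weak reference has to be created while a strong reference still exists, and only after the strong reference is cleared can the loop observe get() returning null; the attempt counter is an extra safeguard so the wait cannot spin forever.

import java.lang.ref.WeakReference

object WaitForGcSketch {
  def main(args: Array[String]): Unit = {
    // Create the weak reference first, then drop the only strong reference.
    var payload: AnyRef = new Array[Byte](1 << 20)
    val ref = new WeakReference[AnyRef](payload)
    payload = null

    var attempts = 0
    while (ref.get() != null && attempts < 50) {
      System.gc()
      Thread.sleep(100L)
      attempts += 1
    }
    assert(ref.get() == null, s"object not collected after $attempts GC attempts")
  }
}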

