[SPARK-48370][CONNECT] Checkpoint and localCheckpoint in Scala Spark Connect client #46683

Dataset.scala (Spark Connect Scala client)
@@ -3402,20 +3402,105 @@ class Dataset[T] private[sql] (
df
}

def checkpoint(): Dataset[T] = {
throw new UnsupportedOperationException("checkpoint is not implemented.")
}
/**
* Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to
* truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
* where the plan may grow exponentially. It will be saved to files inside the checkpoint
* directory set with `SparkContext#setCheckpointDir`.
*
* @group basic
* @since 4.0.0
*/
def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)

def checkpoint(eager: Boolean): Dataset[T] = {
throw new UnsupportedOperationException("checkpoint is not implemented.")
}
/**
* Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
* logical plan of this Dataset, which is especially useful in iterative algorithms where the
* plan may grow exponentially. It will be saved to files inside the checkpoint directory set
* with `SparkContext#setCheckpointDir`.
*
* @param eager
* Whether to checkpoint this dataframe immediately
*
* @note
* When checkpoint is used with eager = false, the final data that is checkpointed after the
* first action may be different from the data that was used during the job due to
* non-determinism of the underlying operation and retries. If checkpoint is used to achieve
* saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is
* only deterministic after the first execution, after the checkpoint was finalized.
*
* @group basic
* @since 4.0.0
*/
def checkpoint(eager: Boolean): Dataset[T] =
checkpoint(eager = eager, reliableCheckpoint = true)

def localCheckpoint(): Dataset[T] = {
throw new UnsupportedOperationException("localCheckpoint is not implemented.")
}
/**
* Eagerly locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be used
* to truncate the logical plan of this Dataset, which is especially useful in iterative
* algorithms where the plan may grow exponentially. Local checkpoints are written to executor
* storage and, despite being potentially faster, they are unreliable and may compromise job
* completion.
*
* @group basic
* @since 4.0.0
*/
def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)

/**
* Locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be used to
* truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
* where the plan may grow exponentially. Local checkpoints are written to executor storage and,
* despite being potentially faster, they are unreliable and may compromise job completion.
*
* @param eager
* Whether to checkpoint this dataframe immediately
*
* @note
* When checkpoint is used with eager = false, the final data that is checkpointed after the
* first action may be different from the data that was used during the job due to
* non-determinism of the underlying operation and retries. If checkpoint is used to achieve
* saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is
* only deterministic after the first execution, after the checkpoint was finalized.
*
* @group basic
* @since 4.0.0
*/
def localCheckpoint(eager: Boolean): Dataset[T] =
checkpoint(eager = eager, reliableCheckpoint = false)

def localCheckpoint(eager: Boolean): Dataset[T] = {
throw new UnsupportedOperationException("localCheckpoint is not implemented.")
/**
* Returns a checkpointed version of this Dataset.
*
* @param eager
* Whether to checkpoint this dataframe immediately
* @param reliableCheckpoint
* Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If
* false creates a local checkpoint using the caching subsystem
*/
private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder) { builder =>
val command = sparkSession.newCommand { builder =>
builder.getCheckpointCommandBuilder
.setLocal(!reliableCheckpoint) // the proto flag marks a *local* checkpoint, the inverse of reliable
.setEager(eager)
.setRelation(this.plan.getRoot)
}
val responseIter = sparkSession.execute(command)
try {
val response = responseIter
.find(_.hasCheckpointCommandResult)
.getOrElse(throw new RuntimeException("CheckpointCommandResult must be present"))

val cachedRemoteRelation = response.getCheckpointCommandResult.getRelation
sparkSession.cleaner.registerCachedRemoteRelationForCleanup(cachedRemoteRelation)

// Update the builder with the values from the result.
builder.setCachedRemoteRelation(cachedRemoteRelation)
} finally {
// consume the rest of the iterator
responseIter.foreach(_ => ())
}
}
}
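
For context, a minimal usage sketch of the new API from a Spark Connect Scala client. This is illustrative only: the connection string is a placeholder, and it assumes a checkpoint directory has been configured on the server side for reliable checkpoints.

import org.apache.spark.sql.SparkSession

object CheckpointExample {
  def main(args: Array[String]): Unit = {
    // Assumes a Spark Connect server is reachable at this (hypothetical) address.
    val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()

    val df = spark.range(0, 1000).toDF("id")

    // Reliable, eager checkpoint: materializes the data now and truncates the logical plan.
    val reliable = df.checkpoint()

    // Local, lazy checkpoint: written to executor storage, materialized by the first action.
    val local = df.localCheckpoint(eager = false)

    reliable.show()
    local.show()

    spark.close()
  }
}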

SparkSession.scala (Spark Connect Scala client)
@@ -41,7 +41,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf}
import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf}
import org.apache.spark.sql.streaming.DataStreamReader
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.types.StructType
@@ -73,6 +73,11 @@ class SparkSession private[sql] (
with Logging {

private[this] val allocator = new RootAllocator()
private var shouldStopCleaner = false
private[sql] lazy val cleaner = {
shouldStopCleaner = true
new SessionCleaner(this)
}

// a unique session ID for this session from client.
private[sql] def sessionId: String = client.sessionId
@@ -714,6 +719,9 @@ class SparkSession private[sql] (
if (releaseSessionOnClose) {
client.releaseSession()
}
if (shouldStopCleaner) {
cleaner.stop()
}
client.shutdown()
allocator.close()
SparkSession.onSessionClose(this)
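
The cleaner is created lazily so that sessions which never checkpoint do not spawn a cleanup thread, and `shouldStopCleaner` records whether the lazy val was ever forced so that close() does not force it just to stop it. A standalone sketch of that pattern, using illustrative names that are not Spark classes:

class Resource {
  def stop(): Unit = println("resource stopped")
}

class Owner extends AutoCloseable {
  private var shouldStopResource = false

  // Created only on first use; forcing it flips the flag so close() knows to stop it.
  lazy val resource: Resource = {
    shouldStopResource = true
    new Resource
  }

  override def close(): Unit = {
    // Touching `resource` here would force its creation, so consult the flag instead.
    if (shouldStopResource) {
      resource.stop()
    }
  }
}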
SessionCleaner.scala (new file, Spark Connect Scala client)
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.internal

import java.lang.ref.{ReferenceQueue, WeakReference}
import java.util.Collections
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession

/**
* Classes that represent cleaning tasks.
*/
private sealed trait CleanupTask
private case class CleanupCachedRemoteRelation(dfID: String) extends CleanupTask

/**
* A WeakReference associated with a CleanupTask.
*
* When the referent object becomes only weakly reachable, the corresponding
* CleanupTaskWeakReference is automatically added to the given reference queue.
*/
private class CleanupTaskWeakReference(
val task: CleanupTask,
referent: AnyRef,
referenceQueue: ReferenceQueue[AnyRef])
extends WeakReference(referent, referenceQueue)

/**
* An asynchronous cleaner for objects.
*
* This maintains a weak reference for each CachedRemoteRelation of interest, to be processed
* when the associated object goes out of scope of the application. Actual cleanup is performed in
* a separate daemon thread.
*/
private[sql] class SessionCleaner(session: SparkSession) extends Logging {

/**
* How long (in milliseconds) to block on the reference queue before re-checking whether the
* cleaning thread should stop. Cleanups are triggered only when weak references are garbage
* collected; in long-running applications with little memory pressure on the client JVM this
* may happen rarely or not at all, leaving cached relations on the server longer than needed.
*/
private val refQueuePollTimeout: Long = 100

/**
* A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
* have not been handled by the reference queue.
*/
private val referenceBuffer =
Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap)

private val referenceQueue = new ReferenceQueue[AnyRef]

private val cleaningThread = new Thread() { override def run(): Unit = keepCleaning() }

@volatile private var started = false
@volatile private var stopped = false

/** Start the cleaner. */
def start(): Unit = {
cleaningThread.setDaemon(true)
cleaningThread.setName("Spark Connect Context Cleaner")
cleaningThread.start()
}

/**
* Stop the cleaning thread and wait until the thread has finished running its current task.
*/
def stop(): Unit = {
stopped = true
// Interrupt the cleaning thread, but wait until the current task has finished before
// doing so. This guards against the race condition where a cleaning thread may
// potentially clean similarly named variables created by a different SparkSession.
synchronized {
cleaningThread.interrupt()
}
cleaningThread.join()
}

/** Register a CachedRemoteRelation for cleanup when it is garbage collected. */
def registerCachedRemoteRelationForCleanup(relation: proto.CachedRemoteRelation): Unit = {
registerForCleanup(relation, CleanupCachedRemoteRelation(relation.getRelationId))
}

/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
if (!started) {
// Lazily starts when the first cleanup is registered.
start()
started = true
}
referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue))
}

/** Keep cleaning objects. */
private def keepCleaning(): Unit = {
while (!stopped && !session.client.channel.isShutdown) {
try {
val reference = Option(referenceQueue.remove(refQueuePollTimeout))
.map(_.asInstanceOf[CleanupTaskWeakReference])
// Synchronize here to avoid being interrupted on stop()
synchronized {
reference.foreach { ref =>
logDebug("Got cleaning task " + ref.task)
referenceBuffer.remove(ref)
ref.task match {
case CleanupCachedRemoteRelation(dfID) =>
doCleanupCachedRemoteRelation(dfID)
}
}
}
} catch {
case e: Throwable => logError("Error in cleaning thread", e)
}
}
}

/** Perform CleanupCachedRemoteRelation cleanup. */
private[spark] def doCleanupCachedRemoteRelation(dfID: String): Unit = {
session.execute {
session.newCommand { builder =>
builder.getRemoveCachedRemoteRelationCommandBuilder
.setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfID).build())
}
}
}
}
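
The cleanup mechanism itself is the standard WeakReference/ReferenceQueue pattern. A self-contained sketch of just that mechanism, unrelated to Spark (note that collection after System.gc() is not guaranteed, which is why the cleaner polls with a timeout):

import java.lang.ref.{ReferenceQueue, WeakReference}

object RefQueueDemo {
  def main(args: Array[String]): Unit = {
    val queue = new ReferenceQueue[AnyRef]
    var payload: AnyRef = new Object
    val ref = new WeakReference[AnyRef](payload, queue)

    payload = null // drop the only strong reference
    System.gc()    // request a collection; the JVM may or may not honor it immediately

    // remove(timeout) returns an enqueued reference or null, mirroring the
    // refQueuePollTimeout polling loop in keepCleaning().
    val enqueued = Option(queue.remove(1000))
    println(s"enqueued: ${enqueued.isDefined}, matches our reference: ${enqueued.contains(ref)}")
  }
}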