From 78d24d90444080707fff1857de124fafb6274bc2 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 21 May 2024 19:12:00 +0900 Subject: [PATCH 1/5] Checkpoint and localCheckpoint in Scala Spark Connect client --- .../scala/org/apache/spark/sql/Dataset.scala | 130 ++++++++++++++++-- .../apache/spark/sql/ClientE2ETestSuite.scala | 32 +++++ .../connect/client/SparkConnectClient.scala | 2 +- 3 files changed, 151 insertions(+), 13 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 37f770319b695..673583005c213 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ @@ -42,6 +43,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkClassUtils +// scalastyle:off no.finalize /** * A Dataset is a strongly typed collection of domain-specific objects that can be transformed in * parallel using functional or relational operations. Each Dataset also has an untyped view @@ -132,7 +134,8 @@ class Dataset[T] private[sql] ( val sparkSession: SparkSession, @DeveloperApi val plan: proto.Plan, val encoder: Encoder[T]) - extends Serializable { + extends Serializable + with Logging { // Make sure we don't forget to set plan id. assert(plan.getRoot.getCommon.hasPlanId) @@ -3402,20 +3405,103 @@ class Dataset[T] private[sql] ( df } - def checkpoint(): Dataset[T] = { - throw new UnsupportedOperationException("checkpoint is not implemented.") - } + /** + * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to + * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms + * where the plan may grow exponentially. It will be saved to files inside the checkpoint + * directory set with `SparkContext#setCheckpointDir`. + * + * @group basic + * @since 4.0.0 + */ + def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) - def checkpoint(eager: Boolean): Dataset[T] = { - throw new UnsupportedOperationException("checkpoint is not implemented.") - } + /** + * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the + * logical plan of this Dataset, which is especially useful in iterative algorithms where the + * plan may grow exponentially. It will be saved to files inside the checkpoint directory set + * with `SparkContext#setCheckpointDir`. + * + * @param eager + * Whether to checkpoint this dataframe immediately + * + * @note + * When checkpoint is used with eager = false, the final data that is checkpointed after the + * first action may be different from the data that was used during the job due to + * non-determinism of the underlying operation and retries. If checkpoint is used to achieve + * saving a deterministic snapshot of the data, eager = true should be used. 
Otherwise, it is + * only deterministic after the first execution, after the checkpoint was finalized. + * + * @group basic + * @since 4.0.0 + */ + def checkpoint(eager: Boolean): Dataset[T] = + checkpoint(eager = eager, reliableCheckpoint = true) - def localCheckpoint(): Dataset[T] = { - throw new UnsupportedOperationException("localCheckpoint is not implemented.") - } + /** + * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used + * to truncate the logical plan of this Dataset, which is especially useful in iterative + * algorithms where the plan may grow exponentially. Local checkpoints are written to executor + * storage and despite potentially faster they are unreliable and may compromise job completion. + * + * @group basic + * @since 4.0.0 + */ + def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false) - def localCheckpoint(eager: Boolean): Dataset[T] = { - throw new UnsupportedOperationException("localCheckpoint is not implemented.") + /** + * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to + * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms + * where the plan may grow exponentially. Local checkpoints are written to executor storage and + * despite potentially faster they are unreliable and may compromise job completion. + * + * @param eager + * Whether to checkpoint this dataframe immediately + * + * @note + * When checkpoint is used with eager = false, the final data that is checkpointed after the + * first action may be different from the data that was used during the job due to + * non-determinism of the underlying operation and retries. If checkpoint is used to achieve + * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is + * only deterministic after the first execution, after the checkpoint was finalized. + * + * @group basic + * @since 4.0.0 + */ + def localCheckpoint(eager: Boolean): Dataset[T] = + checkpoint(eager = eager, reliableCheckpoint = false) + + /** + * Returns a checkpointed version of this Dataset. + * + * @param eager + * Whether to checkpoint this dataframe immediately + * @param reliableCheckpoint + * Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If + * false creates a local checkpoint using the caching subsystem + */ + private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { + val df = sparkSession.newDataset(agnosticEncoder) { builder => + val command = sparkSession.newCommand { builder => + builder.getCheckpointCommandBuilder + .setLocal(reliableCheckpoint) + .setEager(eager) + .setRelation(this.plan.getRoot) + } + val responseIter = sparkSession.execute(command) + try { + val response = responseIter + .find(_.hasCheckpointCommandResult) + .getOrElse(throw new RuntimeException("CheckpointCommandResult must be present")) + // Update the builder with the values from the result. 
+ builder.setCachedRemoteRelation(response.getCheckpointCommandResult.getRelation) + } finally { + // consume the rest of the iterator + responseIter.foreach(_ => ()) + } + } + df.cachedRemoteRelationID = Some(df.plan.getRoot.getCachedRemoteRelation.getRelationId) + df } /** @@ -3468,6 +3554,26 @@ class Dataset[T] private[sql] ( } } + // Visible for testing + private[sql] var cachedRemoteRelationID: Option[String] = None + + override def finalize(): Unit = { + if (!sparkSession.client.channel.isShutdown) { + cachedRemoteRelationID.foreach { dfId => + try { + sparkSession.execute { + sparkSession.newCommand { builder => + builder.getRemoveCachedRemoteRelationCommandBuilder + .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfId).build()) + } + } + } catch { + case e: Throwable => logWarning("RemoveRemoteCachedRelation failed.", e) + } + } + } + } + /** * We cannot deserialize a connect [[Dataset]] because of a class clash on the server side. We * null out the instance for now. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 255dd76697987..b75dfb2b4e49f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -33,6 +33,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException} import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} +import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, TableAlreadyExistsException, TempTableAlreadyExistsException} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema @@ -1558,6 +1559,37 @@ class ClientE2ETestSuite val metrics = SparkThreadUtils.awaitResult(future, 2.seconds) assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98)) } + + test("checkpoint") { + val df = spark.range(100).localCheckpoint() + testCapturedStdOut(df.explain(), "ExistingRDD") + } + + test("checkpoint gc") { + var df1 = spark.range(100).localCheckpoint(eager = true) + val encoder = df1.agnosticEncoder + val dfId = df1.cachedRemoteRelationID.get + + // GC triggers remove the cached remote relation + df1 = null + System.gc() + + // Make sure the cleanup happens in the server side. 
+ Thread.sleep(3000L) + + val ex = intercept[SparkException] { + spark + .newDataset(encoder) { builder => + builder.setCachedRemoteRelation( + proto.CachedRemoteRelation + .newBuilder() + .setRelationId(dfId) + .build()) + } + .collect() + } + assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found")) + } } private[sql] case class ClassData(a: String, b: Int) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 1e7b4e6574ddb..b5eda024bfb3c 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.connect.common.config.ConnectCommon */ private[sql] class SparkConnectClient( private[sql] val configuration: SparkConnectClient.Configuration, - private val channel: ManagedChannel) { + private[sql] val channel: ManagedChannel) { private val userContext: UserContext = configuration.userContext From 84a8dedbc3675038ff310faebdea21a94052e1c7 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 22 May 2024 11:26:06 +0900 Subject: [PATCH 2/5] Uses weakreferences and cleaner --- .../scala/org/apache/spark/sql/Dataset.scala | 31 +-- .../org/apache/spark/sql/SparkSession.scala | 10 +- .../spark/sql/internal/ContextCleaner.scala | 176 ++++++++++++++++++ .../apache/spark/sql/ClientE2ETestSuite.scala | 71 +++++-- 4 files changed, 247 insertions(+), 41 deletions(-) create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ContextCleaner.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 673583005c213..4d68ac7fb6e93 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -43,7 +43,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SparkClassUtils -// scalastyle:off no.finalize /** * A Dataset is a strongly typed collection of domain-specific objects that can be transformed in * parallel using functional or relational operations. Each Dataset also has an untyped view @@ -3481,7 +3480,7 @@ class Dataset[T] private[sql] ( * false creates a local checkpoint using the caching subsystem */ private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { - val df = sparkSession.newDataset(agnosticEncoder) { builder => + sparkSession.newDataset(agnosticEncoder) { builder => val command = sparkSession.newCommand { builder => builder.getCheckpointCommandBuilder .setLocal(reliableCheckpoint) @@ -3493,15 +3492,17 @@ class Dataset[T] private[sql] ( val response = responseIter .find(_.hasCheckpointCommandResult) .getOrElse(throw new RuntimeException("CheckpointCommandResult must be present")) + + val cachedRemoteRelation = response.getCheckpointCommandResult.getRelation + sparkSession.cleaner.registerCachedRemoteRelationForCleanup(cachedRemoteRelation) + // Update the builder with the values from the result. 
- builder.setCachedRemoteRelation(response.getCheckpointCommandResult.getRelation) + builder.setCachedRemoteRelation(cachedRemoteRelation) } finally { // consume the rest of the iterator responseIter.foreach(_ => ()) } } - df.cachedRemoteRelationID = Some(df.plan.getRoot.getCachedRemoteRelation.getRelationId) - df } /** @@ -3554,26 +3555,6 @@ class Dataset[T] private[sql] ( } } - // Visible for testing - private[sql] var cachedRemoteRelationID: Option[String] = None - - override def finalize(): Unit = { - if (!sparkSession.client.channel.isShutdown) { - cachedRemoteRelationID.foreach { dfId => - try { - sparkSession.execute { - sparkSession.newCommand { builder => - builder.getRemoveCachedRemoteRelationCommandBuilder - .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfId).build()) - } - } - } catch { - case e: Throwable => logWarning("RemoveRemoteCachedRelation failed.", e) - } - } - } - } - /** * We cannot deserialize a connect [[Dataset]] because of a class clash on the server side. We * null out the instance for now. diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1188fba60a2fe..5c539c611e85d 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf} +import org.apache.spark.sql.internal.{CatalogImpl, ContextCleaner, SqlApiConf} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.types.StructType @@ -73,6 +73,11 @@ class SparkSession private[sql] ( with Logging { private[this] val allocator = new RootAllocator() + private var shouldStopCleaner = false + private[sql] lazy val cleaner = { + shouldStopCleaner = true + new ContextCleaner(this) + } // a unique session ID for this session from client. private[sql] def sessionId: String = client.sessionId @@ -714,6 +719,9 @@ class SparkSession private[sql] ( if (releaseSessionOnClose) { client.releaseSession() } + if (shouldStopCleaner) { + cleaner.stop() + } client.shutdown() allocator.close() SparkSession.onSessionClose(this) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ContextCleaner.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ContextCleaner.scala new file mode 100644 index 0000000000000..c4148b0d7a4d1 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ContextCleaner.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.lang.ref.{ReferenceQueue, WeakReference} +import java.util.Collections +import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, ScheduledThreadPoolExecutor, TimeUnit} + +import com.google.common.util.concurrent.ThreadFactoryBuilder + +import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession + +/** + * Classes that represent cleaning tasks. + */ +private sealed trait CleanupTask +private case class CleanupCachedRemoteRelation(dfID: String) extends CleanupTask + +/** + * A WeakReference associated with a CleanupTask. + * + * When the referent object becomes only weakly reachable, the corresponding + * CleanupTaskWeakReference is automatically added to the given reference queue. + */ +private class CleanupTaskWeakReference( + val task: CleanupTask, + referent: AnyRef, + referenceQueue: ReferenceQueue[AnyRef]) + extends WeakReference(referent, referenceQueue) + +/** + * An asynchronous cleaner for objects. + * + * This maintains a weak reference for each CashRemoteRelation, etc. of interest, to be processed + * when the associated object goes out of scope of the application. Actual cleanup is performed in + * a separate daemon thread. + */ +private[spark] class ContextCleaner(session: SparkSession) extends Logging { + + /** + * How often (seconds) to trigger a garbage collection in this JVM. This context cleaner + * triggers cleanups only when weak references are garbage collected. In long-running + * applications with large driver JVMs, where there is little memory pressure on the driver, + * this may happen very occasionally or not at all. Not cleaning at all may lead to executors + * running out of disk space after a while. + */ + private val periodicGCInterval: Long = 30 * 60 + private val refQueuePollTimeout: Long = 100 + + /** + * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they + * have not been handled by the reference queue. + */ + private val referenceBuffer = + Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap) + + private val referenceQueue = new ReferenceQueue[AnyRef] + + private val cleaningThread = new Thread() { override def run(): Unit = keepCleaning() } + + private val periodicGCService: ScheduledExecutorService = + ContextCleaner.newDaemonSingleThreadScheduledExecutor( + "spark-connect-context-cleaner-periodic-gc") + + @volatile private var started = false + @volatile private var stopped = false + + /** Start the cleaner. */ + def start(): Unit = { + cleaningThread.setDaemon(true) + cleaningThread.setName("Spark Connect Context Cleaner") + cleaningThread.start() + periodicGCService.scheduleAtFixedRate( + () => System.gc(), + periodicGCInterval, + periodicGCInterval, + TimeUnit.SECONDS) + } + + /** + * Stop the cleaning thread and wait until the thread has finished running its current task. + */ + def stop(): Unit = { + stopped = true + // Interrupt the cleaning thread, but wait until the current task has finished before + // doing so. 
This guards against the race condition where a cleaning thread may + // potentially clean similarly named variables created by a different SparkSession. + synchronized { + cleaningThread.interrupt() + } + cleaningThread.join() + periodicGCService.shutdown() + } + + /** Register a CachedRemoteRelation for cleanup when it is garbage collected. */ + def registerCachedRemoteRelationForCleanup(relation: proto.CachedRemoteRelation): Unit = { + registerForCleanup(relation, CleanupCachedRemoteRelation(relation.getRelationId)) + } + + /** Register an object for cleanup. */ + private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { + if (!started) { + // Lazily starts when the first cleanup is registered. + start() + started = true + } + referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)) + } + + /** Keep cleaning objects. */ + private def keepCleaning(): Unit = { + while (!stopped && !session.client.channel.isShutdown) { + try { + val reference = Option(referenceQueue.remove(refQueuePollTimeout)) + .map(_.asInstanceOf[CleanupTaskWeakReference]) + // Synchronize here to avoid being interrupted on stop() + synchronized { + reference.foreach { ref => + logDebug("Got cleaning task " + ref.task) + referenceBuffer.remove(ref) + ref.task match { + case CleanupCachedRemoteRelation(dfID) => + doCleanupCachedRemoteRelation(dfID) + } + } + } + } catch { + case e: Throwable => logError("Error in cleaning thread", e) + } + } + } + + /** Perform CleanupCachedRemoteRelation cleanup. */ + private[spark] def doCleanupCachedRemoteRelation(dfID: String): Unit = { + session.execute { + session.newCommand { builder => + builder.getRemoveCachedRemoteRelationCommandBuilder + .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfID).build()) + } + } + } +} + +private object ContextCleaner { + + /** + * Wrapper over ScheduledThreadPoolExecutor the pool with daemon threads. + */ + private def newDaemonSingleThreadScheduledExecutor( + threadName: String): ScheduledExecutorService = { + val threadFactory = + new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() + val executor = new ScheduledThreadPoolExecutor(1, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. 
+ executor.setRemoveOnCancelPolicy(true) + executor + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index b75dfb2b4e49f..5de6c70faf1f1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql import java.io.{ByteArrayOutputStream, PrintStream} +import java.lang.ref.WeakReference import java.nio.file.Files import java.time.DateTimeException import java.util.Properties @@ -30,6 +31,8 @@ import org.apache.commons.io.FileUtils import org.apache.commons.io.output.TeeOutputStream import org.scalactic.TolerantNumerics import org.scalatest.PrivateMethodTester +import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} +import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException} import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} @@ -1568,27 +1571,65 @@ class ClientE2ETestSuite test("checkpoint gc") { var df1 = spark.range(100).localCheckpoint(eager = true) val encoder = df1.agnosticEncoder - val dfId = df1.cachedRemoteRelationID.get + val dfId = df1.plan.getRoot.getCachedRemoteRelation.getRelationId // GC triggers remove the cached remote relation df1 = null - System.gc() + val ref = new WeakReference[Object](df1) + while (ref.get() != null) { Thread.sleep(1000L); System.gc() } - // Make sure the cleanup happens in the server side. - Thread.sleep(3000L) + eventually(timeout(30.seconds), interval(500.millis)) { + val ex = intercept[SparkException] { + spark + .newDataset(encoder) { builder => + builder.setCachedRemoteRelation( + proto.CachedRemoteRelation + .newBuilder() + .setRelationId(dfId) + .build()) + } + .collect() + } + assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found")) + } + } - val ex = intercept[SparkException] { - spark - .newDataset(encoder) { builder => - builder.setCachedRemoteRelation( - proto.CachedRemoteRelation - .newBuilder() - .setRelationId(dfId) - .build()) - } - .collect() + test("checkpoint gc derived DataFrame") { + var df1 = spark.range(100).localCheckpoint(eager = true) + var derived = df1.repartition(10) + val encoder = df1.agnosticEncoder + val dfId = df1.plan.getRoot.getCachedRemoteRelation.getRelationId + + df1 = null + val ref = new WeakReference[Object](df1) + while (ref.get() != null) { Thread.sleep(1000L); System.gc() } + + def condition(): Unit = { + val ex = intercept[SparkException] { + spark + .newDataset(encoder) { builder => + builder.setCachedRemoteRelation( + proto.CachedRemoteRelation + .newBuilder() + .setRelationId(dfId) + .build()) + } + .collect() + } + assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found")) } - assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found")) + + intercept[TestFailedDueToTimeoutException] { + eventually(timeout(5.seconds), interval(500.millis))(condition()) + } + + // GC triggers remove the cached remote relation + derived = null + val ref1 = new WeakReference[Object](df1) + while (ref1.get() != null) { Thread.sleep(1000L); System.gc() } + + // Check the state was removed up on garbage-collection. 
+ eventually(timeout(30.seconds), interval(500.millis))(condition()) } } From acbaf6a57123fec2beacc3100db726eacb43e881 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 22 May 2024 13:02:47 +0900 Subject: [PATCH 3/5] Simpler --- .../org/apache/spark/sql/SparkSession.scala | 4 +-- ...textCleaner.scala => SessionCleaner.scala} | 34 ++----------------- 2 files changed, 4 insertions(+), 34 deletions(-) rename connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/{ContextCleaner.scala => SessionCleaner.scala} (80%) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 5c539c611e85d..91ee0f52e8bd0 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{CatalogImpl, ContextCleaner, SqlApiConf} +import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.types.StructType @@ -76,7 +76,7 @@ class SparkSession private[sql] ( private var shouldStopCleaner = false private[sql] lazy val cleaner = { shouldStopCleaner = true - new ContextCleaner(this) + new SessionCleaner(this) } // a unique session ID for this session from client. diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ContextCleaner.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala similarity index 80% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ContextCleaner.scala rename to connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala index c4148b0d7a4d1..f78e5cf4c4ed8 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ContextCleaner.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.internal import java.lang.ref.{ReferenceQueue, WeakReference} import java.util.Collections -import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, ScheduledThreadPoolExecutor, TimeUnit} - -import com.google.common.util.concurrent.ThreadFactoryBuilder +import java.util.concurrent.ConcurrentHashMap import org.apache.spark.connect.proto import org.apache.spark.internal.Logging @@ -52,7 +50,7 @@ private class CleanupTaskWeakReference( * when the associated object goes out of scope of the application. Actual cleanup is performed in * a separate daemon thread. */ -private[spark] class ContextCleaner(session: SparkSession) extends Logging { +private[spark] class SessionCleaner(session: SparkSession) extends Logging { /** * How often (seconds) to trigger a garbage collection in this JVM. This context cleaner @@ -61,7 +59,6 @@ private[spark] class ContextCleaner(session: SparkSession) extends Logging { * this may happen very occasionally or not at all. 
Not cleaning at all may lead to executors * running out of disk space after a while. */ - private val periodicGCInterval: Long = 30 * 60 private val refQueuePollTimeout: Long = 100 /** @@ -75,10 +72,6 @@ private[spark] class ContextCleaner(session: SparkSession) extends Logging { private val cleaningThread = new Thread() { override def run(): Unit = keepCleaning() } - private val periodicGCService: ScheduledExecutorService = - ContextCleaner.newDaemonSingleThreadScheduledExecutor( - "spark-connect-context-cleaner-periodic-gc") - @volatile private var started = false @volatile private var stopped = false @@ -87,11 +80,6 @@ private[spark] class ContextCleaner(session: SparkSession) extends Logging { cleaningThread.setDaemon(true) cleaningThread.setName("Spark Connect Context Cleaner") cleaningThread.start() - periodicGCService.scheduleAtFixedRate( - () => System.gc(), - periodicGCInterval, - periodicGCInterval, - TimeUnit.SECONDS) } /** @@ -106,7 +94,6 @@ private[spark] class ContextCleaner(session: SparkSession) extends Logging { cleaningThread.interrupt() } cleaningThread.join() - periodicGCService.shutdown() } /** Register a CachedRemoteRelation for cleanup when it is garbage collected. */ @@ -157,20 +144,3 @@ private[spark] class ContextCleaner(session: SparkSession) extends Logging { } } } - -private object ContextCleaner { - - /** - * Wrapper over ScheduledThreadPoolExecutor the pool with daemon threads. - */ - private def newDaemonSingleThreadScheduledExecutor( - threadName: String): ScheduledExecutorService = { - val threadFactory = - new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() - val executor = new ScheduledThreadPoolExecutor(1, threadFactory) - // By default, a cancelled task is not automatically removed from the work queue until its delay - // elapses. We have to enable it manually. - executor.setRemoveOnCancelPolicy(true) - executor - } -} From eb6bdafaa2ddf1fc5221ef9d3c0c94938c13b6e2 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 22 May 2024 14:34:04 +0900 Subject: [PATCH 4/5] fix mima complaints --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 4 +--- .../org/apache/spark/sql/internal/SessionCleaner.scala | 2 +- .../org/apache/spark/sql/ClientE2ETestSuite.scala | 6 +++--- .../client/CheckConnectJvmClientCompatibility.scala | 10 ++++++++++ 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 4d68ac7fb6e93..fc9766357cb22 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -28,7 +28,6 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ @@ -133,8 +132,7 @@ class Dataset[T] private[sql] ( val sparkSession: SparkSession, @DeveloperApi val plan: proto.Plan, val encoder: Encoder[T]) - extends Serializable - with Logging { + extends Serializable { // Make sure we don't forget to set plan id. 
assert(plan.getRoot.getCommon.hasPlanId) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala index f78e5cf4c4ed8..036ea4a84fa97 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala @@ -50,7 +50,7 @@ private class CleanupTaskWeakReference( * when the associated object goes out of scope of the application. Actual cleanup is performed in * a separate daemon thread. */ -private[spark] class SessionCleaner(session: SparkSession) extends Logging { +private[sql] class SessionCleaner(session: SparkSession) extends Logging { /** * How often (seconds) to trigger a garbage collection in this JVM. This context cleaner diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 5de6c70faf1f1..bd6966d169327 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -1578,7 +1578,7 @@ class ClientE2ETestSuite val ref = new WeakReference[Object](df1) while (ref.get() != null) { Thread.sleep(1000L); System.gc() } - eventually(timeout(30.seconds), interval(500.millis)) { + eventually(timeout(60.seconds), interval(1.second)) { val ex = intercept[SparkException] { spark .newDataset(encoder) { builder => @@ -1620,7 +1620,7 @@ class ClientE2ETestSuite } intercept[TestFailedDueToTimeoutException] { - eventually(timeout(5.seconds), interval(500.millis))(condition()) + eventually(timeout(5.seconds), interval(1.second))(condition()) } // GC triggers remove the cached remote relation @@ -1629,7 +1629,7 @@ class ClientE2ETestSuite while (ref1.get() != null) { Thread.sleep(1000L); System.gc() } // Check the state was removed up on garbage-collection. 
- eventually(timeout(30.seconds), interval(500.millis))(condition()) + eventually(timeout(60.seconds), interval(1.second))(condition()) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 374d8464deebf..2e4bbab8d3a41 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -334,6 +334,16 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[ReversedMissingMethodProblem]( "org.apache.spark.sql.SQLImplicits._sqlContext" // protected ), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.internal.SessionCleaner"), + + // private + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.internal.CleanupTask"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.internal.CleanupTaskWeakReference"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.internal.CleanupCachedRemoteRelation"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.internal.CleanupCachedRemoteRelation$"), // Catalyst Refactoring ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.util.SparkCollectionUtils"), From 71774ddd3a72a759111acb74656c801c339d1088 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 22 May 2024 19:19:36 +0900 Subject: [PATCH 5/5] Separate suite --- .../apache/spark/sql/CheckpointSuite.scala | 117 ++++++++++++++++++ .../apache/spark/sql/ClientE2ETestSuite.scala | 73 ----------- 2 files changed, 117 insertions(+), 73 deletions(-) create mode 100644 connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala new file mode 100644 index 0000000000000..e57b051890f56 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +import java.io.{ByteArrayOutputStream, PrintStream} + +import scala.concurrent.duration.DurationInt + +import org.apache.commons.io.output.TeeOutputStream +import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} +import org.scalatest.exceptions.TestFailedDueToTimeoutException + +import org.apache.spark.SparkException +import org.apache.spark.connect.proto +import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} + +class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper { + + private def captureStdOut(block: => Unit): String = { + val currentOut = Console.out + val capturedOut = new ByteArrayOutputStream() + val newOut = new PrintStream(new TeeOutputStream(currentOut, capturedOut)) + Console.withOut(newOut) { + block + } + capturedOut.toString + } + + private def checkFragments(result: String, fragmentsToCheck: Seq[String]): Unit = { + fragmentsToCheck.foreach { fragment => + assert(result.contains(fragment)) + } + } + + private def testCapturedStdOut(block: => Unit, fragmentsToCheck: String*): Unit = { + checkFragments(captureStdOut(block), fragmentsToCheck) + } + + test("checkpoint") { + val df = spark.range(100).localCheckpoint() + testCapturedStdOut(df.explain(), "ExistingRDD") + } + + test("checkpoint gc") { + val df = spark.range(100).localCheckpoint(eager = true) + val encoder = df.agnosticEncoder + val dfId = df.plan.getRoot.getCachedRemoteRelation.getRelationId + spark.cleaner.doCleanupCachedRemoteRelation(dfId) + + val ex = intercept[SparkException] { + spark + .newDataset(encoder) { builder => + builder.setCachedRemoteRelation( + proto.CachedRemoteRelation + .newBuilder() + .setRelationId(dfId) + .build()) + } + .collect() + } + assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found")) + } + + // This test is flaky because cannot guarantee GC + // You can locally run this to verify the behavior. + ignore("checkpoint gc derived DataFrame") { + var df1 = spark.range(100).localCheckpoint(eager = true) + var derived = df1.repartition(10) + val encoder = df1.agnosticEncoder + val dfId = df1.plan.getRoot.getCachedRemoteRelation.getRelationId + + df1 = null + System.gc() + Thread.sleep(3000L) + + def condition(): Unit = { + val ex = intercept[SparkException] { + spark + .newDataset(encoder) { builder => + builder.setCachedRemoteRelation( + proto.CachedRemoteRelation + .newBuilder() + .setRelationId(dfId) + .build()) + } + .collect() + } + assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found")) + } + + intercept[TestFailedDueToTimeoutException] { + eventually(timeout(5.seconds), interval(1.second))(condition()) + } + + // GC triggers remove the cached remote relation + derived = null + System.gc() + Thread.sleep(3000L) + + // Check the state was removed up on garbage-collection. 
+ eventually(timeout(60.seconds), interval(1.second))(condition()) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index bd6966d169327..255dd76697987 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql import java.io.{ByteArrayOutputStream, PrintStream} -import java.lang.ref.WeakReference import java.nio.file.Files import java.time.DateTimeException import java.util.Properties @@ -31,12 +30,9 @@ import org.apache.commons.io.FileUtils import org.apache.commons.io.output.TeeOutputStream import org.scalactic.TolerantNumerics import org.scalatest.PrivateMethodTester -import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} -import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException} import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} -import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, TableAlreadyExistsException, TempTableAlreadyExistsException} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema @@ -1562,75 +1558,6 @@ class ClientE2ETestSuite val metrics = SparkThreadUtils.awaitResult(future, 2.seconds) assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98)) } - - test("checkpoint") { - val df = spark.range(100).localCheckpoint() - testCapturedStdOut(df.explain(), "ExistingRDD") - } - - test("checkpoint gc") { - var df1 = spark.range(100).localCheckpoint(eager = true) - val encoder = df1.agnosticEncoder - val dfId = df1.plan.getRoot.getCachedRemoteRelation.getRelationId - - // GC triggers remove the cached remote relation - df1 = null - val ref = new WeakReference[Object](df1) - while (ref.get() != null) { Thread.sleep(1000L); System.gc() } - - eventually(timeout(60.seconds), interval(1.second)) { - val ex = intercept[SparkException] { - spark - .newDataset(encoder) { builder => - builder.setCachedRemoteRelation( - proto.CachedRemoteRelation - .newBuilder() - .setRelationId(dfId) - .build()) - } - .collect() - } - assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found")) - } - } - - test("checkpoint gc derived DataFrame") { - var df1 = spark.range(100).localCheckpoint(eager = true) - var derived = df1.repartition(10) - val encoder = df1.agnosticEncoder - val dfId = df1.plan.getRoot.getCachedRemoteRelation.getRelationId - - df1 = null - val ref = new WeakReference[Object](df1) - while (ref.get() != null) { Thread.sleep(1000L); System.gc() } - - def condition(): Unit = { - val ex = intercept[SparkException] { - spark - .newDataset(encoder) { builder => - builder.setCachedRemoteRelation( - proto.CachedRemoteRelation - .newBuilder() - .setRelationId(dfId) - .build()) - } - .collect() - } - assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found")) - } - - intercept[TestFailedDueToTimeoutException] { - eventually(timeout(5.seconds), interval(1.second))(condition()) - } - - // GC triggers remove the cached remote relation - derived = null - val ref1 = new WeakReference[Object](df1) - while (ref1.get() 
!= null) { Thread.sleep(1000L); System.gc() } - - // Check the state was removed up on garbage-collection. - eventually(timeout(60.seconds), interval(1.second))(condition()) - } } private[sql] case class ClassData(a: String, b: Int)
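
A brief usage sketch (not part of the patches above) of the API surface this series adds to the Spark Connect Scala client. The connection string, port, and sample data below are placeholder assumptions; only checkpoint(), localCheckpoint(), and the eager flag come from the changes in this series.

    // Assumes a Spark Connect server is reachable at the given address
    // ("sc://localhost:15002" is only an illustrative default).
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()
    val df = spark.range(1000).filter("id % 2 = 0")

    // Reliable checkpoint: eagerly materialized to the checkpoint directory
    // configured on the server via SparkContext#setCheckpointDir.
    val reliable = df.checkpoint()

    // Local checkpoint: written to executor storage; faster but less reliable.
    val local = df.localCheckpoint(eager = true)

    // Both return a Dataset whose plan is truncated to a CachedRemoteRelation.
    // When the client-side Dataset becomes unreachable, the SessionCleaner asks
    // the server to drop the cached relation (RemoveCachedRemoteRelationCommand).
    reliable.explain()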