Checkpoint and localCheckpoint in Spark Connect
HyukjinKwon committed May 14, 2024
1 parent d9ff78e commit 6a3f14d
Showing 19 changed files with 678 additions and 354 deletions.
@@ -94,6 +94,7 @@ message AnalyzePlanRequest {
Persist persist = 14;
Unpersist unpersist = 15;
GetStorageLevel get_storage_level = 16;
Checkpoint checkpoint = 18;
}

message Schema {
@@ -199,6 +200,17 @@ message AnalyzePlanRequest {
// (Required) The logical plan to get the storage level.
Relation relation = 1;
}

message Checkpoint {
// (Required) The logical plan to checkpoint.
Relation relation = 1;

// (Optional) Whether to use localCheckpoint instead of checkpoint.
optional bool local = 2;

// (Optional) Whether to checkpoint eagerly.
optional bool eager = 3;
}
}

// Response to performing analysis of the query. Contains relevant metadata to be able to
@@ -224,6 +236,7 @@ message AnalyzePlanResponse {
Persist persist = 12;
Unpersist unpersist = 13;
GetStorageLevel get_storage_level = 14;
Checkpoint checkpoint = 16;
}

message Schema {
@@ -275,6 +288,11 @@ message AnalyzePlanResponse {
// (Required) The StorageLevel as a result of get_storage_level request.
StorageLevel storage_level = 1;
}

message Checkpoint {
// (Required) The logical plan that was checkpointed, as a cached remote relation.
CachedRemoteRelation relation = 1;
}
}

// A request to be executed by the service.
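For orientation, a minimal sketch (not part of the commit) of the new analyze round-trip, using the generated pb2 classes the way core.py imports them; the session id and the relation being checkpointed are placeholders:

    import pyspark.sql.connect.proto as pb2

    # A trivial plan to checkpoint, as a placeholder.
    rel = pb2.Relation()
    rel.sql.query = "SELECT 1 AS id"

    # Ask the server to localCheckpoint the plan eagerly. Leaving `local` unset
    # would request a reliable checkpoint instead (see the analyze handler below).
    req = pb2.AnalyzePlanRequest()
    req.session_id = "placeholder-session-id"
    req.checkpoint.relation.CopyFrom(rel)
    req.checkpoint.local = True
    req.checkpoint.eager = True

    # The response's checkpoint.relation is a CachedRemoteRelation whose
    # relation_id names the server-side DataFrame backing the result:
    #   resp = stub.AnalyzePlan(req)
    #   cached_id = resp.checkpoint.relation.relation_id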
@@ -45,6 +45,7 @@ message Command {
StreamingQueryListenerBusCommand streaming_query_listener_bus_command = 11;
CommonInlineUserDefinedDataSource register_data_source = 12;
CreateResourceProfileCommand create_resource_profile_command = 13;
RemoveCachedRemoteRelationCommand remove_cached_remote_relation_command = 14;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// Commands they can add them here. During the planning the correct resolution is done.
@@ -484,3 +485,9 @@ message CreateResourceProfileCommandResult {
// (Required) Server-side generated resource profile id.
int32 profile_id = 1;
}

// Command to remove a `CachedRemoteRelation`.
message RemoveCachedRemoteRelationCommand {
// (Required) The cached remote relation to remove (its ID was assigned by the service).
CachedRemoteRelation relation = 1;
}
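A hedged sketch of how a client might use this command to drop the server-side cached DataFrame once it no longer needs it (for instance when the client-side DataFrame wrapping the cached relation is released); `cached_id` stands for the relation_id obtained from the Checkpoint analyze response:

    import pyspark.sql.connect.proto as pb2

    cached_id = "relation-id-from-the-checkpoint-response"  # placeholder

    # Build the cleanup command; the server drops the entry from its
    # per-session DataFrame cache when it executes this.
    command = pb2.Command()
    command.remove_cached_remote_relation_command.relation.relation_id = cached_id
    # The command is then submitted through the usual ExecutePlan path.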
@@ -39,7 +39,7 @@ import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.SESSION_ID
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
@@ -2581,6 +2581,8 @@ class SparkConnectPlanner(
handleCreateResourceProfileCommand(
command.getCreateResourceProfileCommand,
responseObserver)
case proto.Command.CommandTypeCase.REMOVE_CACHED_REMOTE_RELATION_COMMAND =>
handleRemoveCachedRemoteRelationCommand(command.getRemoveCachedRemoteRelationCommand)

case _ => throw new UnsupportedOperationException(s"$command not supported.")
}
@@ -3507,6 +3509,14 @@ class SparkConnectPlanner(
.build())
}

private def handleRemoveCachedRemoteRelationCommand(
removeCachedRemoteRelationCommand: proto.RemoveCachedRemoteRelationCommand): Unit = {
val dfId = removeCachedRemoteRelationCommand.getRelation.getRelationId
logInfo(log"Removing DataFrame with id ${MDC(DATAFRAME_ID, dfId)} from the cache")
sessionHolder.removeCachedDataFrame(dfId)
executeHolder.eventsManager.postFinished()
}

private val emptyLocalRelation = LocalRelation(
output = AttributeReference("value", StringType, false)() :: Nil,
data = Seq.empty)
@@ -106,7 +106,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession

// Mapping from relation ID (passed to client) to runtime dataframe. Used for callbacks like
// foreachBatch() in Streaming. Lazy since most sessions don't need it.
private lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap()
private[spark] lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap()

// Mapping from id to StreamingQueryListener. Used for methods like removeListener() in
// StreamingQueryManager.
@@ -17,12 +17,15 @@

package org.apache.spark.sql.connect.service

import java.util.UUID

import scala.jdk.CollectionConverters._

import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.DATAFRAME_ID
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
@@ -206,6 +209,29 @@ private[connect] class SparkConnectAnalyzeHandler(
.setStorageLevel(StorageLevelProtoConverter.toConnectProtoType(storageLevel))
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.CHECKPOINT =>
val target = Dataset
.ofRows(session, planner.transformRelation(request.getCheckpoint.getRelation))
val checkpointed = if (request.getCheckpoint.hasLocal && request.getCheckpoint.hasEager) {
target.localCheckpoint(eager = request.getCheckpoint.getEager)
} else if (request.getCheckpoint.hasLocal) {
target.localCheckpoint()
} else if (request.getCheckpoint.hasEager) {
target.checkpoint(eager = request.getCheckpoint.getEager)
} else {
target.checkpoint()
}

val dfId = UUID.randomUUID().toString
logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")
sessionHolder.cacheDataFrameById(dfId, checkpointed)

builder.setCheckpoint(
proto.AnalyzePlanResponse.Checkpoint
.newBuilder()
.setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfId).build())
.build())

case other => throw InvalidPlanInput(s"Unknown Analyze Method $other!")
}

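Once the handler has cached the checkpointed Dataset under `dfId`, later plans can refer to it by that id. A small sketch, assuming the `cached_remote_relation` field that Relation already carries for server-side cached DataFrames (as used by streaming foreachBatch); `df_id` is a placeholder:

    import pyspark.sql.connect.proto as pb2

    df_id = "relation-id-returned-by-the-checkpoint-response"  # placeholder

    # A plan that points at the server-side cached DataFrame; when this is
    # planned, SessionHolder resolves the id against its dataFrameCache.
    rel = pb2.Relation()
    rel.cached_remote_relation.relation_id = df_id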
@@ -315,6 +315,12 @@ object SparkConnectService extends Logging {
previoslyObservedSessionId)
}

// For testing
private[spark] def getOrCreateIsolatedSession(
userId: String, sessionId: String): SessionHolder = {
getOrCreateIsolatedSession(userId, sessionId, None)
}

/**
* If there are no executions, return Left with System.currentTimeMillis of last active
* execution. Otherwise return Right with list of ExecuteInfo of all executions.
@@ -24,18 +24,14 @@
class SparkFrameMethodsParityTests(
SparkFrameMethodsTestsMixin, TestUtils, PandasOnSparkTestUtils, ReusedConnectTestCase
):
@unittest.skip("Test depends on checkpoint which is not supported from Spark Connect.")
@unittest.skip("Test depends on SparkContext which is not supported from Spark Connect.")
def test_checkpoint(self):
super().test_checkpoint()

@unittest.skip("Test depends on RDD which is not supported from Spark Connect.")
def test_coalesce(self):
super().test_coalesce()

@unittest.skip("Test depends on localCheckpoint which is not supported from Spark Connect.")
def test_local_checkpoint(self):
super().test_local_checkpoint()

@unittest.skip("Test depends on RDD which is not supported from Spark Connect.")
def test_repartition(self):
super().test_repartition()
19 changes: 18 additions & 1 deletion python/pyspark/sql/connect/client/core.py
@@ -66,7 +66,11 @@
from pyspark.sql.connect.profiler import ConnectProfilerCollector
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy
from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level
from pyspark.sql.connect.conversion import (
storage_level_to_proto,
proto_to_storage_level,
proto_to_remote_cached_dataframe,
)
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
import pyspark.sql.connect.types as types
@@ -100,6 +104,7 @@
from google.rpc.error_details_pb2 import ErrorInfo
from pyspark.sql.connect._typing import DataTypeOrString
from pyspark.sql.datasource import DataSource
from pyspark.sql.connect.dataframe import DataFrame


class ChannelBuilder:
@@ -528,6 +533,7 @@ def __init__(
is_same_semantics: Optional[bool],
semantic_hash: Optional[int],
storage_level: Optional[StorageLevel],
replaced: Optional["DataFrame"],
):
self.schema = schema
self.explain_string = explain_string
@@ -540,6 +546,7 @@ def __init__(
self.is_same_semantics = is_same_semantics
self.semantic_hash = semantic_hash
self.storage_level = storage_level
self.replaced = replaced

@classmethod
def fromProto(cls, pb: Any) -> "AnalyzeResult":
@@ -554,6 +561,7 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
is_same_semantics: Optional[bool] = None
semantic_hash: Optional[int] = None
storage_level: Optional[StorageLevel] = None
replaced: Optional["DataFrame"] = None

if pb.HasField("schema"):
schema = types.proto_schema_to_pyspark_data_type(pb.schema.schema)
@@ -581,6 +589,8 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
pass
elif pb.HasField("get_storage_level"):
storage_level = proto_to_storage_level(pb.get_storage_level.storage_level)
elif pb.HasField("checkpoint"):
replaced = proto_to_remote_cached_dataframe(pb.checkpoint.relation)
else:
raise SparkConnectException("No analyze result found!")

@@ -596,6 +606,7 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
is_same_semantics,
semantic_hash,
storage_level,
replaced,
)


@@ -1229,6 +1240,12 @@ def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
req.unpersist.blocking = cast(bool, kwargs.get("blocking"))
elif method == "get_storage_level":
req.get_storage_level.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
elif method == "checkpoint":
req.checkpoint.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
if kwargs.get("local", None) is not None:
req.checkpoint.local = cast(bool, kwargs.get("local"))
if kwargs.get("eager", None) is not None:
req.checkpoint.eager = cast(bool, kwargs.get("eager"))
else:
raise PySparkValueError(
error_class="UNSUPPORTED_OPERATION",
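The DataFrame-facing methods in pyspark/sql/connect/dataframe.py (not shown in this excerpt) presumably funnel into this "checkpoint" analyze branch; a hedged sketch of that call path, not the exact implementation:

    # Hedged sketch: how a client-side localCheckpoint could be expressed in
    # terms of the new "checkpoint" analyze method. The attribute names
    # (`_session`, `_plan`) follow the existing connect DataFrame internals.
    def local_checkpoint_sketch(df, eager: bool = True):
        client = df._session.client
        result = client._analyze(
            method="checkpoint",
            relation=df._plan.plan(client),  # serialize the plan to proto
            local=True,
            eager=eager,
        )
        # `replaced` is the new DataFrame wrapping the CachedRemoteRelation.
        return result.replaced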
23 changes: 17 additions & 6 deletions python/pyspark/sql/connect/conversion.py
@@ -48,12 +48,10 @@
import pyspark.sql.connect.proto as pb2
from pyspark.sql.pandas.types import to_arrow_schema, _dedup_names, _deduplicate_field_names

from typing import (
Any,
Callable,
Sequence,
List,
)
from typing import Any, Callable, Sequence, List, TYPE_CHECKING

if TYPE_CHECKING:
from pyspark.sql.connect.dataframe import DataFrame


class LocalDataToArrowConversion:
@@ -570,3 +568,16 @@ def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel:
deserialized=storage_level.deserialized,
replication=storage_level.replication,
)


def proto_to_remote_cached_dataframe(relation: pb2.CachedRemoteRelation) -> "DataFrame":
assert relation is not None and isinstance(relation, pb2.CachedRemoteRelation)

from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.session import SparkSession
import pyspark.sql.connect.plan as plan

return DataFrame(
plan=plan.CachedRemoteRelation(relation.relation_id),
session=SparkSession.active(),
)
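Taken together, the user-facing effect is that the existing DataFrame checkpoint API now works over Spark Connect; a short usage sketch (assuming `spark` is a Spark Connect session and the server has a checkpoint directory configured for the reliable case):

    df = spark.range(10)

    # Reliable checkpoint: materialized to the server's checkpoint directory.
    reliable = df.checkpoint()

    # Local, lazy checkpoint: kept on the executors, materialized on first use.
    lazy_local = df.localCheckpoint(eager=False)

    # Both return DataFrames backed by a server-side CachedRemoteRelation.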