From 21f2d4063a79f8cefca2601655a3baa723b2ead0 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 22 May 2024 10:10:36 +0900 Subject: [PATCH] followup --- python/pyspark/sql/connect/conversion.py | 5 +- python/pyspark/sql/connect/dataframe.py | 38 ------------- python/pyspark/sql/connect/plan.py | 54 ++++++++++++++++--- python/pyspark/sql/connect/session.py | 2 +- .../sql/tests/connect/test_connect_basic.py | 53 +++++++++++++++--- 5 files changed, 97 insertions(+), 55 deletions(-) diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py index b1cf88e40a4e8..1c205586d6096 100644 --- a/python/pyspark/sql/connect/conversion.py +++ b/python/pyspark/sql/connect/conversion.py @@ -577,7 +577,8 @@ def proto_to_remote_cached_dataframe(relation: pb2.CachedRemoteRelation) -> "Dat from pyspark.sql.connect.session import SparkSession import pyspark.sql.connect.plan as plan + session = SparkSession.active() return DataFrame( - plan=plan.CachedRemoteRelation(relation.relation_id), - session=SparkSession.active(), + plan=plan.CachedRemoteRelation(relation.relation_id, session), + session=session, ) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 3725bc3ba0e40..510776bb752d3 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -16,7 +16,6 @@ # # mypy: disable-error-code="override" -from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2 from pyspark.errors.exceptions.base import ( SessionNotSameException, PySparkIndexError, @@ -138,41 +137,6 @@ def __init__( # by __repr__ and _repr_html_ while eager evaluation opens. self._support_repr_html = False self._cached_schema: Optional[StructType] = None - self._cached_remote_relation_id: Optional[str] = None - - def __del__(self) -> None: - # If session is already closed, all cached DataFrame should be released. - if not self._session.client.is_closed and self._cached_remote_relation_id is not None: - try: - command = plan.RemoveRemoteCachedRelation( - plan.CachedRemoteRelation(relationId=self._cached_remote_relation_id) - ).command(session=self._session.client) - req = self._session.client._execute_plan_request_with_metadata() - if self._session.client._user_id: - req.user_context.user_id = self._session.client._user_id - req.plan.command.CopyFrom(command) - - for attempt in self._session.client._retrying(): - with attempt: - # !!HACK ALERT!! - # unary_stream does not work on Python's exit for an unknown reasons - # Therefore, here we open unary_unary channel instead. - # See also :class:`SparkConnectServiceStub`. - request_serializer = ( - spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString - ) - response_deserializer = ( - spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString - ) - channel = self._session.client._channel.unary_unary( - "/spark.connect.SparkConnectService/ExecutePlan", - request_serializer=request_serializer, - response_deserializer=response_deserializer, - ) - metadata = self._session.client._builder.metadata() - channel(req, metadata=metadata) # type: ignore[arg-type] - except Exception as e: - warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.") def __reduce__(self) -> Tuple: """ @@ -2137,7 +2101,6 @@ def checkpoint(self, eager: bool = True) -> "DataFrame": assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) - checkpointed._cached_remote_relation_id = checkpointed._plan._relationId return checkpointed def localCheckpoint(self, eager: bool = True) -> "DataFrame": @@ -2146,7 +2109,6 @@ def localCheckpoint(self, eager: bool = True) -> "DataFrame": assert "checkpoint_command_result" in properties checkpointed = properties["checkpoint_command_result"] assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) - checkpointed._cached_remote_relation_id = checkpointed._plan._relationId return checkpointed if not is_remote_only(): diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 94c2641bb4d21..868bd4fb57aa4 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -40,6 +40,7 @@ import pickle from threading import Lock from inspect import signature, isclass +import warnings import pyarrow as pa @@ -49,6 +50,7 @@ import pyspark.sql.connect.proto as proto from pyspark.sql.column import Column +from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2 from pyspark.sql.connect.conversion import storage_level_to_proto from pyspark.sql.connect.expressions import Expression from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType @@ -62,6 +64,7 @@ from pyspark.sql.connect.client import SparkConnectClient from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.observation import Observation + from pyspark.sql.connect.session import SparkSession class LogicalPlan: @@ -547,14 +550,49 @@ class CachedRemoteRelation(LogicalPlan): """Logical plan object for a DataFrame reference which represents a DataFrame that's been cached on the server with a given id.""" - def __init__(self, relationId: str): + def __init__(self, relation_id: str, spark_session: "SparkSession"): super().__init__(None) - self._relationId = relationId - - def plan(self, session: "SparkConnectClient") -> proto.Relation: - plan = self._create_proto_relation() - plan.cached_remote_relation.relation_id = self._relationId - return plan + self._relation_id = relation_id + # Needs to hold the session to make a request itself. + self._spark_session = spark_session + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + plan = self._create_proto_relation() + plan.cached_remote_relation.relation_id = self._relation_id + return plan + + def __del__(self) -> None: + session = self._spark_session + # If session is already closed, all cached DataFrame should be released. + if session is not None and not session.client.is_closed and self._relation_id is not None: + try: + command = RemoveRemoteCachedRelation(self).command(session=session.client) + req = session.client._execute_plan_request_with_metadata() + if session.client._user_id: + req.user_context.user_id = session.client._user_id + req.plan.command.CopyFrom(command) + + for attempt in session.client._retrying(): + with attempt: + # !!HACK ALERT!! + # unary_stream does not work on Python's exit for an unknown reasons + # Therefore, here we open unary_unary channel instead. + # See also :class:`SparkConnectServiceStub`. + request_serializer = ( + spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString + ) + response_deserializer = ( + spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString + ) + channel = session.client._channel.unary_unary( + "/spark.connect.SparkConnectService/ExecutePlan", + request_serializer=request_serializer, + response_deserializer=response_deserializer, + ) + metadata = session.client._builder.metadata() + channel(req, metadata=metadata) # type: ignore[arg-type] + except Exception as e: + warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.") class Hint(LogicalPlan): @@ -1792,7 +1830,7 @@ def __init__(self, relation: CachedRemoteRelation) -> None: def command(self, session: "SparkConnectClient") -> proto.Command: plan = self._create_proto_relation() - plan.cached_remote_relation.relation_id = self._relation._relationId + plan.cached_remote_relation.relation_id = self._relation._relation_id cmd = proto.Command() cmd.remove_cached_remote_relation_command.relation.CopyFrom(plan.cached_remote_relation) return cmd diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 5e6c5e5587646..f99d298ea1170 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -926,7 +926,7 @@ def _create_remote_dataframe(self, remote_id: str) -> "ParentDataFrame": This is used in ForeachBatch() runner, where the remote DataFrame refers to the output of a micro batch. """ - return DataFrame(CachedRemoteRelation(remote_id), self) + return DataFrame(CachedRemoteRelation(remote_id, spark_session=self), self) @staticmethod def _start_connect_server(master: str, opts: Dict[str, Any]) -> None: diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index b144c3b8de208..0648b5ce9925c 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -16,10 +16,10 @@ # import os +import gc import unittest import shutil import tempfile -import time from pyspark.util import is_remote_only from pyspark.errors import PySparkTypeError, PySparkValueError @@ -34,6 +34,7 @@ ArrayType, Row, ) +from pyspark.testing.utils import eventually from pyspark.testing.sqlutils import SQLTestUtils from pyspark.testing.connectutils import ( should_test_connect, @@ -1379,8 +1380,8 @@ def test_garbage_collection_checkpoint(self): # SPARK-48258: Make sure garbage-collecting DataFrame remove the paired state # in Spark Connect server df = self.connect.range(10).localCheckpoint() - self.assertIsNotNone(df._cached_remote_relation_id) - cached_remote_relation_id = df._cached_remote_relation_id + self.assertIsNotNone(df._plan._relation_id) + cached_remote_relation_id = df._plan._relation_id jvm = self.spark._jvm session_holder = getattr( @@ -1397,14 +1398,54 @@ def test_garbage_collection_checkpoint(self): ) del df + gc.collect() - time.sleep(3) # Make sure removing is triggered, and executed in the server. + def condition(): + # Check the state was removed up on garbage-collection. + self.assertIsNone( + session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None) + ) + + eventually(catch_assertions=True)(condition)() + + def test_garbage_collection_derived_checkpoint(self): + # SPARK-48258: Should keep the cached remote relation when derived DataFrames exist + df = self.connect.range(10).localCheckpoint() + self.assertIsNotNone(df._plan._relation_id) + derived = df.repartition(10) + cached_remote_relation_id = df._plan._relation_id - # Check the state was removed up on garbage-collection. - self.assertIsNone( + jvm = self.spark._jvm + session_holder = getattr( + getattr( + jvm.org.apache.spark.sql.connect.service, + "SparkConnectService$", + ), + "MODULE$", + ).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id) + + # Check the state exists. + self.assertIsNotNone( session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None) ) + del df + gc.collect() + + def condition(): + self.assertIsNone( + session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None) + ) + + # Should not remove the cache + with self.assertRaises(AssertionError): + eventually(catch_assertions=True, timeout=5)(condition)() + + del derived + gc.collect() + + eventually(catch_assertions=True)(condition)() + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401