followup
HyukjinKwon committed May 22, 2024
1 parent a1e27a3 commit 21f2d40
Showing 5 changed files with 97 additions and 55 deletions.
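
This follow-up to SPARK-48258 moves the server-side cache cleanup out of DataFrame.__del__ and into CachedRemoteRelation.__del__ in plan.py, so the RemoveRemoteCachedRelation request is sent only once the last DataFrame referencing the cached plan is garbage-collected. As a rough illustration of why that helps, the sketch below uses hypothetical names (CachedRelationHandle, FakeClient) that are not part of PySpark; it only demonstrates the general Python pattern of letting a shared plan object own the release logic so derived DataFrames keep the cached relation alive.

import warnings


class CachedRelationHandle:
    """Hypothetical stand-in for a plan node that owns a server-side cached relation."""

    def __init__(self, relation_id, client):
        self._relation_id = relation_id
        self._client = client  # kept so cleanup can still issue a release request

    def __del__(self):
        # Runs only once no DataFrame-like wrapper references this handle anymore.
        try:
            self._client.release(self._relation_id)
        except Exception as e:
            # Mirror the warn-and-continue style of the real cleanup path.
            warnings.warn(f"release failed: {e}")


class FakeClient:
    """Hypothetical client; the real code sends a RemoveRemoteCachedRelation command."""

    def release(self, relation_id):
        print(f"released {relation_id}")


handle = CachedRelationHandle("rel-1", FakeClient())
df = {"plan": handle}       # original "DataFrame" holding the shared plan
derived = {"plan": handle}  # derived "DataFrame" sharing the same plan object

del df       # no release yet: `derived` still references the handle
del derived  # handle becomes unreachable, __del__ fires and "rel-1" is released
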
5 changes: 3 additions & 2 deletions python/pyspark/sql/connect/conversion.py
@@ -577,7 +577,8 @@ def proto_to_remote_cached_dataframe(relation: pb2.CachedRemoteRelation) -> "DataFrame":
     from pyspark.sql.connect.session import SparkSession
     import pyspark.sql.connect.plan as plan
 
+    session = SparkSession.active()
     return DataFrame(
-        plan=plan.CachedRemoteRelation(relation.relation_id),
-        session=SparkSession.active(),
+        plan=plan.CachedRemoteRelation(relation.relation_id, session),
+        session=session,
     )
38 changes: 0 additions & 38 deletions python/pyspark/sql/connect/dataframe.py
@@ -16,7 +16,6 @@
 #
 
 # mypy: disable-error-code="override"
-from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
 from pyspark.errors.exceptions.base import (
     SessionNotSameException,
     PySparkIndexError,
@@ -138,41 +137,6 @@ def __init__(
         # by __repr__ and _repr_html_ while eager evaluation opens.
         self._support_repr_html = False
         self._cached_schema: Optional[StructType] = None
-        self._cached_remote_relation_id: Optional[str] = None
-
-    def __del__(self) -> None:
-        # If session is already closed, all cached DataFrame should be released.
-        if not self._session.client.is_closed and self._cached_remote_relation_id is not None:
-            try:
-                command = plan.RemoveRemoteCachedRelation(
-                    plan.CachedRemoteRelation(relationId=self._cached_remote_relation_id)
-                ).command(session=self._session.client)
-                req = self._session.client._execute_plan_request_with_metadata()
-                if self._session.client._user_id:
-                    req.user_context.user_id = self._session.client._user_id
-                req.plan.command.CopyFrom(command)
-
-                for attempt in self._session.client._retrying():
-                    with attempt:
-                        # !!HACK ALERT!!
-                        # unary_stream does not work on Python's exit for an unknown reasons
-                        # Therefore, here we open unary_unary channel instead.
-                        # See also :class:`SparkConnectServiceStub`.
-                        request_serializer = (
-                            spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
-                        )
-                        response_deserializer = (
-                            spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
-                        )
-                        channel = self._session.client._channel.unary_unary(
-                            "/spark.connect.SparkConnectService/ExecutePlan",
-                            request_serializer=request_serializer,
-                            response_deserializer=response_deserializer,
-                        )
-                        metadata = self._session.client._builder.metadata()
-                        channel(req, metadata=metadata)  # type: ignore[arg-type]
-            except Exception as e:
-                warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")
 
     def __reduce__(self) -> Tuple:
         """
@@ -2137,7 +2101,6 @@ def checkpoint(self, eager: bool = True) -> "DataFrame":
         assert "checkpoint_command_result" in properties
         checkpointed = properties["checkpoint_command_result"]
         assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
-        checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
         return checkpointed
 
     def localCheckpoint(self, eager: bool = True) -> "DataFrame":
@@ -2146,7 +2109,6 @@ def localCheckpoint(self, eager: bool = True) -> "DataFrame":
         assert "checkpoint_command_result" in properties
         checkpointed = properties["checkpoint_command_result"]
         assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
-        checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
         return checkpointed
 
     if not is_remote_only():
54 changes: 46 additions & 8 deletions python/pyspark/sql/connect/plan.py
@@ -40,6 +40,7 @@
 import pickle
 from threading import Lock
 from inspect import signature, isclass
+import warnings
 
 import pyarrow as pa
 
@@ -49,6 +50,7 @@
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.column import Column
+from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
 from pyspark.sql.connect.conversion import storage_level_to_proto
 from pyspark.sql.connect.expressions import Expression
 from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
@@ -62,6 +64,7 @@
     from pyspark.sql.connect.client import SparkConnectClient
     from pyspark.sql.connect.udf import UserDefinedFunction
     from pyspark.sql.connect.observation import Observation
+    from pyspark.sql.connect.session import SparkSession
 
 
 class LogicalPlan:
@@ -547,14 +550,49 @@ class CachedRemoteRelation(LogicalPlan):
     """Logical plan object for a DataFrame reference which represents a DataFrame that's been
     cached on the server with a given id."""
 
-    def __init__(self, relationId: str):
+    def __init__(self, relation_id: str, spark_session: "SparkSession"):
         super().__init__(None)
-        self._relationId = relationId
-
-    def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = self._create_proto_relation()
-        plan.cached_remote_relation.relation_id = self._relationId
-        return plan
+        self._relation_id = relation_id
+        # Needs to hold the session to make a request itself.
+        self._spark_session = spark_session
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = self._create_proto_relation()
+        plan.cached_remote_relation.relation_id = self._relation_id
+        return plan
+
+    def __del__(self) -> None:
+        session = self._spark_session
+        # If session is already closed, all cached DataFrame should be released.
+        if session is not None and not session.client.is_closed and self._relation_id is not None:
+            try:
+                command = RemoveRemoteCachedRelation(self).command(session=session.client)
+                req = session.client._execute_plan_request_with_metadata()
+                if session.client._user_id:
+                    req.user_context.user_id = session.client._user_id
+                req.plan.command.CopyFrom(command)
+
+                for attempt in session.client._retrying():
+                    with attempt:
+                        # !!HACK ALERT!!
+                        # unary_stream does not work on Python's exit for an unknown reasons
+                        # Therefore, here we open unary_unary channel instead.
+                        # See also :class:`SparkConnectServiceStub`.
+                        request_serializer = (
+                            spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
+                        )
+                        response_deserializer = (
+                            spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
+                        )
+                        channel = session.client._channel.unary_unary(
+                            "/spark.connect.SparkConnectService/ExecutePlan",
+                            request_serializer=request_serializer,
+                            response_deserializer=response_deserializer,
+                        )
+                        metadata = session.client._builder.metadata()
+                        channel(req, metadata=metadata)  # type: ignore[arg-type]
+            except Exception as e:
+                warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")
 
 
 class Hint(LogicalPlan):
@@ -1792,7 +1830,7 @@ def __init__(self, relation: CachedRemoteRelation) -> None:
 
     def command(self, session: "SparkConnectClient") -> proto.Command:
         plan = self._create_proto_relation()
-        plan.cached_remote_relation.relation_id = self._relation._relationId
+        plan.cached_remote_relation.relation_id = self._relation._relation_id
         cmd = proto.Command()
         cmd.remove_cached_remote_relation_command.relation.CopyFrom(plan.cached_remote_relation)
         return cmd
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/session.py
@@ -926,7 +926,7 @@ def _create_remote_dataframe(self, remote_id: str) -> "ParentDataFrame":
         This is used in ForeachBatch() runner, where the remote DataFrame refers to the
         output of a micro batch.
         """
-        return DataFrame(CachedRemoteRelation(remote_id), self)
+        return DataFrame(CachedRemoteRelation(remote_id, spark_session=self), self)
 
     @staticmethod
     def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
53 changes: 47 additions & 6 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -16,10 +16,10 @@
 #
 
 import os
+import gc
 import unittest
 import shutil
 import tempfile
-import time
 
 from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
@@ -34,6 +34,7 @@
     ArrayType,
     Row,
 )
+from pyspark.testing.utils import eventually
 from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.testing.connectutils import (
     should_test_connect,
@@ -1379,8 +1380,8 @@ def test_garbage_collection_checkpoint(self):
         # SPARK-48258: Make sure garbage-collecting DataFrame remove the paired state
         # in Spark Connect server
         df = self.connect.range(10).localCheckpoint()
-        self.assertIsNotNone(df._cached_remote_relation_id)
-        cached_remote_relation_id = df._cached_remote_relation_id
+        self.assertIsNotNone(df._plan._relation_id)
+        cached_remote_relation_id = df._plan._relation_id
 
         jvm = self.spark._jvm
         session_holder = getattr(
@@ -1397,14 +1398,54 @@
         )
 
         del df
+        gc.collect()
 
-        time.sleep(3)  # Make sure removing is triggered, and executed in the server.
+        def condition():
+            # Check the state was removed up on garbage-collection.
+            self.assertIsNone(
+                session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
+            )
 
-        # Check the state was removed up on garbage-collection.
-        self.assertIsNone(
+        eventually(catch_assertions=True)(condition)()
+
+    def test_garbage_collection_derived_checkpoint(self):
+        # SPARK-48258: Should keep the cached remote relation when derived DataFrames exist
+        df = self.connect.range(10).localCheckpoint()
+        self.assertIsNotNone(df._plan._relation_id)
+        derived = df.repartition(10)
+        cached_remote_relation_id = df._plan._relation_id
+
+        jvm = self.spark._jvm
+        session_holder = getattr(
+            getattr(
+                jvm.org.apache.spark.sql.connect.service,
+                "SparkConnectService$",
+            ),
+            "MODULE$",
+        ).getOrCreateIsolatedSession(self.connect.client._user_id, self.connect.client._session_id)
+
+        # Check the state exists.
+        self.assertIsNotNone(
             session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
         )
 
+        del df
+        gc.collect()
+
+        def condition():
+            self.assertIsNone(
+                session_holder.dataFrameCache().getOrDefault(cached_remote_relation_id, None)
+            )
+
+        # Should not remove the cache
+        with self.assertRaises(AssertionError):
+            eventually(catch_assertions=True, timeout=5)(condition)()
+
+        del derived
+        gc.collect()
+
+        eventually(catch_assertions=True)(condition)()
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401
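
The updated tests also drop the fixed time.sleep(3) in favor of polling: eventually(catch_assertions=True)(condition)() keeps retrying the assertion until it passes or a timeout expires, and the derived-DataFrame test wraps a short-timeout eventually call in assertRaises to check that the cache is not removed while a derived DataFrame is still alive. As a hedged sketch only (not the actual pyspark.testing.utils.eventually implementation), a minimal retry-until-assertion-passes helper could look like this:

import time


def retry_assertion(condition, timeout=30.0, interval=0.5):
    """Call `condition` repeatedly until it stops raising AssertionError or `timeout` elapses."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            condition()
            return
        except AssertionError:
            if time.monotonic() >= deadline:
                raise
            time.sleep(interval)


# Usage mirroring the test (assert_cache_removed is a hypothetical check):
# retry_assertion(lambda: assert_cache_removed(cached_remote_relation_id), timeout=30.0)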
