Skip to content

Commit

Permalink
[SPARK-48310][PYTHON][CONNECT] Cached properties must return copies
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
When a consumer modifies the list returned by a cached property, it mutates the value stored in the cache itself, so every later access observes the modification.

Before:
```python
df_columns = df.columns
for col in ['id', 'name']:
  df_columns.remove(col)
assert len(df_columns) == len(df.columns)
```

But this is wrong and this patch fixes it to

```python
df_columns = df.columns
for col in ['id', 'name']:
  df_columns.remove(col)
assert len(df_columns) != len(df.columns)
```

### Why are the changes needed?
Correctness of the API

### Does this PR introduce _any_ user-facing change?
No, this makes the code consistent with Spark classic.

### How was this patch tested?
UT

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46621 from grundprinzip/grundprinzip/SPARK-48310.

Authored-by: Martin Grund <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
grundprinzip authored and HyukjinKwon committed May 17, 2024
1 parent 153053f commit 05e1706
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Type,
)

import copy
import sys
import random
import pyarrow as pa
Expand Down Expand Up @@ -1787,7 +1788,7 @@ def schema(self) -> StructType:
if self._cached_schema is None:
query = self._plan.to_proto(self._session.client)
self._cached_schema = self._session.client.schema(query)
return self._cached_schema
return copy.deepcopy(self._cached_schema)

def isLocal(self) -> bool:
query = self._plan.to_proto(self._session.client)
Expand Down
24 changes: 24 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,37 @@

from pyspark.sql.tests.test_dataframe import DataFrameTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.sql.types import StructType, StructField, IntegerType, StringType


class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
def test_help_command(self):
    """Run the shared help-command parity check against a tiny two-row DataFrame."""
    rows = [{"foo": "bar"}, {"foo": "baz"}]
    frame = self.spark.createDataFrame(data=rows)
    super().check_help_command(frame)

def test_cached_property_is_copied(self):
    """Regression test for SPARK-48310: cached properties must return copies.

    Mutating the list returned by ``df.columns`` must not leak back into the
    DataFrame's cached schema — a later ``df.columns`` call still has to
    return the full, unmodified column list.
    """
    schema = StructType(
        [
            StructField("id", IntegerType(), True),
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("city", StringType(), True),
        ]
    )
    # Dummy rows matching the four-column schema above.
    data = [
        (1, "Alice", 30, "New York"),
        (2, "Bob", 25, "San Francisco"),
        (3, "Cathy", 29, "Los Angeles"),
        (4, "David", 35, "Chicago"),
    ]
    df = self.spark.createDataFrame(data, schema)
    df_columns = df.columns
    self.assertEqual(len(df.columns), 4)
    for col in ["id", "name"]:
        df_columns.remove(col)
    # The local copy must actually have shrunk (guards against a silent no-op) ...
    self.assertEqual(len(df_columns), 2)
    # ... while the DataFrame's cached schema stays untouched.
    self.assertEqual(len(df.columns), 4)

# Parity stub: the mixin's version builds its input via an RDD, which Spark
# Connect does not expose, so the inherited test is skipped wholesale.
@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_toDF_with_schema_string(self):
    super().test_toDF_with_schema_string()
Expand Down

0 comments on commit 05e1706

Please sign in to comment.