[SPARK-49609][PYTHON][CONNECT] Add API compatibility check between Classic and Connect #48085

Closed · wants to merge 7 commits
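
In essence, the change adds a test module that compares the inspect.signature of every public method shared by the Classic and Connect implementations of DataFrame, Column, and SparkSession, and checks that any methods or properties missing on the Connect side are limited to a known allow-list. A distilled, standalone sketch of that check (assuming a PySpark build where the Connect client dependencies are importable; the real test is python/pyspark/sql/tests/test_connect_compatibility.py below):

import inspect

from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame

classic_methods = dict(inspect.getmembers(ClassicDataFrame, predicate=inspect.isfunction))
connect_methods = dict(inspect.getmembers(ConnectDataFrame, predicate=inspect.isfunction))

# Only the public methods both implementations define are compared.
for name in sorted(set(classic_methods) & set(connect_methods)):
    if name.startswith("_"):
        continue
    if inspect.signature(classic_methods[name]) != inspect.signature(connect_methods[name]):
        print(f"Signature mismatch in DataFrame method {name!r}")
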
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -550,6 +550,7 @@ def __hash__(self):
         "pyspark.sql.tests.test_resources",
         "pyspark.sql.tests.plot.test_frame_plot",
         "pyspark.sql.tests.plot.test_frame_plot_plotly",
+        "pyspark.sql.tests.test_connect_compatibility",
     ],
 )

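Registering the module in the SQL test goals above makes it part of the regular pyspark-sql test run; locally, the single module can typically be executed with python/run-tests --testnames pyspark.sql.tests.test_connect_compatibility from the usual PySpark development setup.
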
6 changes: 3 additions & 3 deletions python/pyspark/sql/classic/dataframe.py
@@ -1068,7 +1068,7 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> ParentDataFrame:
         jdf = self._jdf.selectExpr(self._jseq(expr))
         return DataFrame(jdf, self.sparkSession)

-    def filter(self, condition: "ColumnOrName") -> ParentDataFrame:
+    def filter(self, condition: Union[Column, str]) -> ParentDataFrame:
         if isinstance(condition, str):
             jdf = self._jdf.filter(condition)
         elif isinstance(condition, Column):
@@ -1809,10 +1809,10 @@ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
     def drop_duplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
         return self.dropDuplicates(subset)

-    def writeTo(self, table: str) -> DataFrameWriterV2:
+    def writeTo(self, table: str) -> "DataFrameWriterV2":
         return DataFrameWriterV2(self, table)

-    def mergeInto(self, table: str, condition: Column) -> MergeIntoWriter:
+    def mergeInto(self, table: str, condition: Column) -> "MergeIntoWriter":
         return MergeIntoWriter(self, table, condition)

     def pandas_api(
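
The quoting changes here are not purely stylistic: inspect.signature records annotations exactly as they are written, so a bare class object and its quoted forward reference produce different Signature objects. A minimal standalone illustration (the names below are stand-ins, not PySpark code):

import inspect


class DataFrameWriterV2:  # stand-in for the real writer class
    pass


def write_to_plain(table: str) -> DataFrameWriterV2:  # annotation is the class object
    pass


def write_to_quoted(table: str) -> "DataFrameWriterV2":  # annotation is the string
    pass


# Equal parameters, but the return annotations differ (class vs. str),
# so the two Signature objects do not compare equal.
assert inspect.signature(write_to_plain) != inspect.signature(write_to_quoted)
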
26 changes: 15 additions & 11 deletions python/pyspark/sql/connect/dataframe.py
@@ -535,7 +535,7 @@ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
     def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
         ...

-    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> GroupedData:
+    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]

@@ -570,7 +570,7 @@ def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
     def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
         ...

-    def rollup(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
+    def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
         _cols: List[Column] = []
         for c in cols:
             if isinstance(c, Column):
@@ -731,8 +731,8 @@ def _convert_col(df: ParentDataFrame, col: "ColumnOrName") -> Column:
             session=self._session,
         )

-    def limit(self, n: int) -> ParentDataFrame:
-        res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
+    def limit(self, num: int) -> ParentDataFrame:
+        res = DataFrame(plan.Limit(child=self._plan, limit=num), session=self._session)
         res._cached_schema = self._cached_schema
         return res

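The n → num rename is not just cosmetic either: parameter names participate in inspect.Signature equality, so the Connect methods have to expose the same parameter names as their Classic counterparts (limit(num) and, further down, offset(num)).
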
@@ -931,7 +931,11 @@ def _show_string(
         )._to_table()
         return table[0][0].as_py()

-    def withColumns(self, colsMap: Dict[str, Column]) -> ParentDataFrame:
+    def withColumns(self, *colsMap: Dict[str, Column]) -> ParentDataFrame:
+        # Below code is to help enable kwargs in future.
+        assert len(colsMap) == 1
+        colsMap = colsMap[0]  # type: ignore[assignment]
+
         if not isinstance(colsMap, dict):
             raise PySparkTypeError(
                 errorClass="NOT_DICT",
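
The *colsMap unpacking added above appears to follow the same pattern as the Classic withColumns (which also takes *colsMap); a tiny standalone illustration of how the single-dict call style is preserved (hypothetical function, not PySpark code):

def with_columns(*cols_map):
    # Exactly one positional dict is accepted for now; the varargs shape
    # leaves room for keyword arguments later.
    assert len(cols_map) == 1
    mapping = cols_map[0]
    return dict(mapping)


print(with_columns({"doubled": "x * 2"}))  # {'doubled': 'x * 2'}
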
@@ -1256,7 +1260,7 @@ def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame:
         res._cached_schema = self._merge_cached_schema(other)
         return res

-    def where(self, condition: Union[Column, str]) -> ParentDataFrame:
+    def where(self, condition: "ColumnOrName") -> ParentDataFrame:
         if not isinstance(condition, (str, Column)):
             raise PySparkTypeError(
                 errorClass="NOT_COLUMN_OR_STR",
@@ -2193,18 +2197,18 @@ def cb(ei: "ExecutionInfo") -> None:

         return DataFrameWriterV2(self._plan, self._session, table, cb)

-    def mergeInto(self, table: str, condition: Column) -> MergeIntoWriter:
+    def mergeInto(self, table: str, condition: Column) -> "MergeIntoWriter":
         def cb(ei: "ExecutionInfo") -> None:
             self._execution_info = ei

         return MergeIntoWriter(
             self._plan, self._session, table, condition, cb  # type: ignore[arg-type]
         )

-    def offset(self, n: int) -> ParentDataFrame:
-        return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)
+    def offset(self, num: int) -> ParentDataFrame:
+        return DataFrame(plan.Offset(child=self._plan, offset=num), session=self._session)

-    def checkpoint(self, eager: bool = True) -> "DataFrame":
+    def checkpoint(self, eager: bool = True) -> ParentDataFrame:
         cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager)
         _, properties, self._execution_info = self._session.client.execute_command(
             cmd.command(self._session.client)
@@ -2214,7 +2218,7 @@ def checkpoint(self, eager: bool = True) -> "DataFrame":
         assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
         return checkpointed

-    def localCheckpoint(self, eager: bool = True) -> "DataFrame":
+    def localCheckpoint(self, eager: bool = True) -> ParentDataFrame:
         cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
         _, properties, self._execution_info = self._session.client.execute_command(
             cmd.command(self._session.client)

3 changes: 2 additions & 1 deletion python/pyspark/sql/session.py
@@ -77,6 +77,7 @@
     from pyspark.sql.udf import UDFRegistration
     from pyspark.sql.udtf import UDTFRegistration
     from pyspark.sql.datasource import DataSourceRegistration
+    from pyspark.sql.dataframe import DataFrame as ParentDataFrame

     # Running MyPy type checks will always require pandas and
     # other dependencies so importing here is fine.
@@ -1641,7 +1642,7 @@ def prepare(obj: Any) -> Any:

     def sql(
         self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None, **kwargs: Any
-    ) -> DataFrame:
+    ) -> "ParentDataFrame":
         """Returns a :class:`DataFrame` representing the result of the given query.
         When ``kwargs`` is specified, this method formats the given string by using the Python
         standard formatter. The method binds named parameters to SQL literals or

188 changes: 188 additions & 0 deletions python/pyspark/sql/tests/test_connect_compatibility.py
@@ -0,0 +1,188 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
import inspect

from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.sql.classic.column import Column as ClassicColumn
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.session import SparkSession as ClassicSparkSession
from pyspark.sql.connect.session import SparkSession as ConnectSparkSession


class ConnectCompatibilityTestsMixin:
Review comment (Contributor): @itholic can we add some test to verify that this keeps functioning? Basically we need a couple of broken signatures. (A sketch of such a negative test follows the file listing below.)

    def get_public_methods(self, cls):
        """Get public methods of a class."""
        return {
            name: method
            for name, method in inspect.getmembers(cls, predicate=inspect.isfunction)
            if not name.startswith("_")
        }

    def get_public_properties(self, cls):
        """Get public properties of a class."""
        return {
            name: member
            for name, member in inspect.getmembers(cls)
            if isinstance(member, property) and not name.startswith("_")
        }

    def test_signature_comparison_between_classic_and_connect(self):
        def compare_method_signatures(classic_cls, connect_cls, cls_name):
            """Compare method signatures between classic and connect classes."""
            classic_methods = self.get_public_methods(classic_cls)
            connect_methods = self.get_public_methods(connect_cls)

            common_methods = set(classic_methods.keys()) & set(connect_methods.keys())

            for method in common_methods:
                classic_signature = inspect.signature(classic_methods[method])
                connect_signature = inspect.signature(connect_methods[method])

                # createDataFrame cannot be the same since RDD is not supported from Spark Connect
                if not method == "createDataFrame":
                    self.assertEqual(
                        classic_signature,
                        connect_signature,
                        f"Signature mismatch in {cls_name} method '{method}'\n"
                        f"Classic: {classic_signature}\n"
                        f"Connect: {connect_signature}",
                    )

        # DataFrame API signature comparison
        compare_method_signatures(ClassicDataFrame, ConnectDataFrame, "DataFrame")

        # Column API signature comparison
        compare_method_signatures(ClassicColumn, ConnectColumn, "Column")

        # SparkSession API signature comparison
        compare_method_signatures(ClassicSparkSession, ConnectSparkSession, "SparkSession")

    def test_property_comparison_between_classic_and_connect(self):
        def compare_property_lists(classic_cls, connect_cls, cls_name, expected_missing_properties):
            """Compare properties between classic and connect classes."""
            classic_properties = self.get_public_properties(classic_cls)
            connect_properties = self.get_public_properties(connect_cls)

            # Identify missing properties
            classic_only_properties = set(classic_properties.keys()) - set(
                connect_properties.keys()
            )

            # Compare the actual missing properties with the expected ones
            self.assertEqual(
                classic_only_properties,
                expected_missing_properties,
                f"{cls_name}: Unexpected missing properties in Connect: {classic_only_properties}",
            )

        # Expected missing properties for DataFrame
        expected_missing_properties_for_dataframe = {"sql_ctx", "isStreaming"}

        # DataFrame properties comparison
        compare_property_lists(
            ClassicDataFrame,
            ConnectDataFrame,
            "DataFrame",
            expected_missing_properties_for_dataframe,
        )

        # Expected missing properties for Column (if any, replace with actual values)
        expected_missing_properties_for_column = set()

        # Column properties comparison
        compare_property_lists(
            ClassicColumn, ConnectColumn, "Column", expected_missing_properties_for_column
        )

        # Expected missing properties for SparkSession
        expected_missing_properties_for_spark_session = {"sparkContext", "version"}

        # SparkSession properties comparison
        compare_property_lists(
            ClassicSparkSession,
            ConnectSparkSession,
            "SparkSession",
            expected_missing_properties_for_spark_session,
        )

    def test_missing_methods(self):
        def check_missing_methods(classic_cls, connect_cls, cls_name, expected_missing_methods):
            """Check for expected missing methods between classic and connect classes."""
            classic_methods = self.get_public_methods(classic_cls)
            connect_methods = self.get_public_methods(connect_cls)

            # Identify missing methods
            classic_only_methods = set(classic_methods.keys()) - set(connect_methods.keys())

            # Compare the actual missing methods with the expected ones
            self.assertEqual(
                classic_only_methods,
                expected_missing_methods,
                f"{cls_name}: Unexpected missing methods in Connect: {classic_only_methods}",
            )

        # Expected missing methods for DataFrame
        expected_missing_methods_for_dataframe = {
            "inputFiles",
            "isLocal",
            "semanticHash",
            "isEmpty",
        }

        # DataFrame missing method check
        check_missing_methods(
            ClassicDataFrame, ConnectDataFrame, "DataFrame", expected_missing_methods_for_dataframe
        )

        # Expected missing methods for Column (if any, replace with actual values)
        expected_missing_methods_for_column = set()

        # Column missing method check
        check_missing_methods(
            ClassicColumn, ConnectColumn, "Column", expected_missing_methods_for_column
        )

        # Expected missing methods for SparkSession (if any, replace with actual values)
        expected_missing_methods_for_spark_session = {"newSession"}

        # SparkSession missing method check
        check_missing_methods(
            ClassicSparkSession,
            ConnectSparkSession,
            "SparkSession",
            expected_missing_methods_for_spark_session,
        )


class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase):
    pass


if __name__ == "__main__":
    from pyspark.sql.tests.test_connect_compatibility import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
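
On the reviewer's question about keeping the checker honest, a hypothetical negative test (not part of this PR; class and method names below are made up) could feed the same Signature comparison two deliberately mismatched definitions and assert that the difference is detected:

import inspect
import unittest


class BrokenSignatureSketch(unittest.TestCase):
    """Sketch only: confirm a deliberate mismatch would be caught."""

    def test_detects_parameter_rename(self):
        class FakeClassic:
            def limit(self, num: int) -> None:
                pass

        class FakeConnect:
            def limit(self, n: int) -> None:  # renamed on purpose
                pass

        classic_signature = inspect.signature(FakeClassic.limit)
        connect_signature = inspect.signature(FakeConnect.limit)
        # The compatibility check relies on Signature equality, so a renamed
        # parameter is enough to make the comparison fail.
        self.assertNotEqual(classic_signature, connect_signature)


if __name__ == "__main__":
    unittest.main()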