[SPARK-42859][CONNECT][PS] Basic support for pandas API on Spark Connect #40525

Status: Closed (26 commits)

Commits
f68a794  Initial Commit (itholic, Mar 22, 2023)
f7c68d5  Fix tox.ini (itholic, Mar 23, 2023)
3d35a7d  Merge branch 'master' of https://github.com/apache/spark into initial… (itholic, Mar 27, 2023)
ef440d8  Merge branch 'master' of https://github.com/apache/spark into initial… (itholic, Mar 28, 2023)
445e1f0  Fix mypy (itholic, Mar 29, 2023)
b02f179  Fix flake (itholic, Mar 29, 2023)
46a33cc  add more tests (itholic, Mar 29, 2023)
d1fa388  Add tests to sparktestsupport (itholic, Mar 29, 2023)
eeeca3e  Skip failed tests (itholic, Mar 29, 2023)
8784c57  Enable Spark Connect testing for internal frame (itholic, Mar 29, 2023)
a7ac189  Add type annotation (itholic, Mar 29, 2023)
0f08451  Merge branch 'master' of https://github.com/apache/spark into initial… (itholic, Mar 29, 2023)
da1a8b7  Merge branch 'master' of https://github.com/apache/spark into initial… (itholic, Mar 29, 2023)
db258df  Fix more tests (itholic, Mar 30, 2023)
1bf4527  Fix test (itholic, Mar 30, 2023)
b1e7656  move tests to pyspark-connect (itholic, Mar 31, 2023)
db69a42  Merge branch 'master' of https://github.com/apache/spark into initial… (itholic, Apr 1, 2023)
639da58  SparkDataFrame -> GenericDataFrame / SparkColumn -> GenericColumn (itholic, Apr 3, 2023)
b5a8d82  Renaming (itholic, Apr 3, 2023)
b1b5daa  Restore (itholic, Apr 5, 2023)
99f628f  Legacy -> PySpark (itholic, Apr 6, 2023)
72bf7ac  Applied comments (itholic, Apr 6, 2023)
dbc4e50  Merge branch 'master' of https://github.com/apache/spark into initial… (itholic, Apr 6, 2023)
397bb9f  Merge branch 'master' of https://github.com/apache/spark into initial… (itholic, Apr 6, 2023)
8588538  resolve conflicts (itholic, Apr 7, 2023)
cd7f48e  Merge branch 'master' of https://github.com/apache/spark into initial… (itholic, Apr 8, 2023)
62 changes: 62 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -778,6 +778,68 @@ def __hash__(self):
# ml unittests
"pyspark.ml.tests.connect.test_connect_function",
"pyspark.ml.tests.connect.test_parity_torch_distributor",
# pandas-on-Spark unittests
Review thread (resolved):

zhengruifeng (Contributor): pyspark-connect takes 2~3 hours after adding the PS-related unit tests. I guess we can add new modules pyspark-connect-pandas and pyspark-connect-pandas-slow. @itholic

itholic (Author): Let me address this within a follow-up PR. Thanks!

(See the sketch of such a module entry after this file's diff.)

"pyspark.pandas.tests.connect.data_type_ops.test_parity_base",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_binary_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_boolean_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_categorical_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_complex_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_date_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_datetime_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_null_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_num_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_string_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_udt_ops",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_timedelta_ops",
"pyspark.pandas.tests.connect.indexes.test_parity_category",
"pyspark.pandas.tests.connect.indexes.test_parity_timedelta",
"pyspark.pandas.tests.connect.plot.test_parity_frame_plot",
"pyspark.pandas.tests.connect.plot.test_parity_frame_plot_matplotlib",
"pyspark.pandas.tests.connect.plot.test_parity_frame_plot_plotly",
"pyspark.pandas.tests.connect.plot.test_parity_series_plot",
"pyspark.pandas.tests.connect.plot.test_parity_series_plot_matplotlib",
"pyspark.pandas.tests.connect.plot.test_parity_series_plot_plotly",
"pyspark.pandas.tests.connect.test_parity_categorical",
"pyspark.pandas.tests.connect.test_parity_config",
"pyspark.pandas.tests.connect.test_parity_csv",
"pyspark.pandas.tests.connect.test_parity_dataframe_conversion",
"pyspark.pandas.tests.connect.test_parity_dataframe_spark_io",
"pyspark.pandas.tests.connect.test_parity_default_index",
"pyspark.pandas.tests.connect.test_parity_expanding",
"pyspark.pandas.tests.connect.test_parity_extension",
"pyspark.pandas.tests.connect.test_parity_frame_spark",
"pyspark.pandas.tests.connect.test_parity_generic_functions",
"pyspark.pandas.tests.connect.test_parity_indexops_spark",
"pyspark.pandas.tests.connect.test_parity_internal",
"pyspark.pandas.tests.connect.test_parity_namespace",
"pyspark.pandas.tests.connect.test_parity_numpy_compat",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_expanding",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling",
"pyspark.pandas.tests.connect.test_parity_repr",
"pyspark.pandas.tests.connect.test_parity_resample",
"pyspark.pandas.tests.connect.test_parity_reshape",
"pyspark.pandas.tests.connect.test_parity_rolling",
"pyspark.pandas.tests.connect.test_parity_scalars",
"pyspark.pandas.tests.connect.test_parity_series_conversion",
"pyspark.pandas.tests.connect.test_parity_series_datetime",
"pyspark.pandas.tests.connect.test_parity_series_string",
"pyspark.pandas.tests.connect.test_parity_spark_functions",
"pyspark.pandas.tests.connect.test_parity_sql",
"pyspark.pandas.tests.connect.test_parity_typedef",
"pyspark.pandas.tests.connect.test_parity_utils",
"pyspark.pandas.tests.connect.test_parity_window",
"pyspark.pandas.tests.connect.indexes.test_parity_base",
"pyspark.pandas.tests.connect.indexes.test_parity_datetime",
"pyspark.pandas.tests.connect.test_parity_dataframe",
"pyspark.pandas.tests.connect.test_parity_dataframe_slow",
"pyspark.pandas.tests.connect.test_parity_groupby",
"pyspark.pandas.tests.connect.test_parity_groupby_slow",
"pyspark.pandas.tests.connect.test_parity_indexing",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_slow",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby",
"pyspark.pandas.tests.connect.test_parity_series",
"pyspark.pandas.tests.connect.test_parity_stats",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
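Regarding zhengruifeng's suggestion above to split the pandas parity tests into their own test modules: a sketch of what such an entry in dev/sparktestsupport/modules.py could look like, modeled on the file's existing Module definitions. The module name, dependency list, and exact test split below are hypothetical; the actual split is deferred to the follow-up PR.

# Hypothetical follow-up module: moves the pandas-on-Spark parity tests out
# of pyspark-connect so that job's runtime drops back to its previous budget.
pyspark_connect_pandas = Module(
    name="pyspark-connect-pandas",
    dependencies=[pyspark_connect],
    source_file_regexes=["python/pyspark/pandas/tests/connect/"],
    python_test_goals=[
        # The parity tests listed above would move here from pyspark-connect,
        # with the slow ones going to a pyspark-connect-pandas-slow sibling.
        "pyspark.pandas.tests.connect.test_parity_dataframe",
        "pyspark.pandas.tests.connect.test_parity_series",
    ],
    excluded_python_implementations=["PyPy"],
)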
1 change: 1 addition & 0 deletions dev/tox.ini
@@ -36,6 +36,7 @@ per-file-ignores =
python/pyspark/ml/tests/*.py: F403,
python/pyspark/mllib/tests/*.py: F403,
python/pyspark/pandas/tests/*.py: F401 F403,
python/pyspark/pandas/tests/connect/*.py: F401 F403,
python/pyspark/resource/tests/*.py: F403,
python/pyspark/sql/tests/*.py: F403,
python/pyspark/streaming/tests/*.py: F403,
10 changes: 10 additions & 0 deletions python/pyspark/pandas/_typing.py
@@ -21,6 +21,12 @@
import numpy as np
from pandas.api.extensions import ExtensionDtype

from pyspark.sql.column import Column as PySparkColumn
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.dataframe import DataFrame as PySparkDataFrame
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame


if TYPE_CHECKING:
from pyspark.pandas.base import IndexOpsMixin
from pyspark.pandas.frame import DataFrame
@@ -49,3 +55,7 @@

DataFrameOrSeries = Union["DataFrame", "Series"]
SeriesOrIndex = Union["Series", "Index"]

# For Spark Connect compatibility.
GenericColumn = Union[PySparkColumn, ConnectColumn]
GenericDataFrame = Union[PySparkDataFrame, ConnectDataFrame]
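These aliases give pandas-on-Spark internals one annotation that covers both backends. A minimal sketch of the intended use; double_scol is a hypothetical helper, not part of this PR:

from pyspark.pandas._typing import GenericColumn

def double_scol(scol: GenericColumn) -> GenericColumn:
    # Works under either backend: both Column classes overload the arithmetic
    # operators, even though one wraps a JVM expression and the other builds
    # a Spark Connect expression tree.
    return scol * 2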
2 changes: 1 addition & 1 deletion python/pyspark/pandas/accessors.py
@@ -171,7 +171,7 @@ def attach_id_column(self, id_type: str, column: Name) -> "DataFrame":
for scol, label in zip(internal.data_spark_columns, internal.column_labels)
]
)
sdf = attach_func(sdf, name_like_string(column))
sdf = attach_func(sdf, name_like_string(column)) # type: ignore[assignment]

return DataFrame(
InternalFrame(
16 changes: 8 additions & 8 deletions python/pyspark/pandas/base.py
@@ -31,7 +31,7 @@
from pyspark.sql.types import LongType, BooleanType, NumericType

from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
from pyspark.pandas._typing import Axis, Dtype, IndexOpsLike, Label, SeriesOrIndex
from pyspark.pandas._typing import Axis, Dtype, IndexOpsLike, Label, SeriesOrIndex, GenericColumn
from pyspark.pandas.config import get_option, option_context
from pyspark.pandas.internal import (
InternalField,
@@ -67,7 +67,7 @@ def should_alignment_for_column_op(self: SeriesOrIndex, other: SeriesOrIndex) ->


def align_diff_index_ops(
func: Callable[..., Column], this_index_ops: SeriesOrIndex, *args: Any
func: Callable[..., GenericColumn], this_index_ops: SeriesOrIndex, *args: Any
) -> SeriesOrIndex:
"""
Align the `IndexOpsMixin` objects and apply the function.
@@ -178,7 +178,7 @@ def align_diff_index_ops(
).rename(that_series.name)


def booleanize_null(scol: Column, f: Callable[..., Column]) -> Column:
def booleanize_null(scol: GenericColumn, f: Callable[..., GenericColumn]) -> GenericColumn:
"""
Booleanize Null in Spark Column
"""
@@ -190,12 +190,12 @@ def booleanize_null(scol: Column, f: Callable[..., Column]) -> Column:
if f in comp_ops:
# if `f` is "!=", fill null with True otherwise False
filler = f == Column.__ne__
scol = F.when(scol.isNull(), filler).otherwise(scol)
scol = F.when(scol.isNull(), filler).otherwise(scol) # type: ignore[arg-type]

return scol


def column_op(f: Callable[..., Column]) -> Callable[..., SeriesOrIndex]:
def column_op(f: Callable[..., GenericColumn]) -> Callable[..., SeriesOrIndex]:
"""
A decorator that wraps APIs taking/returning Spark Column so that pandas-on-Spark Series can be
supported too. If this decorator is used for the `f` function that takes Spark Column and
@@ -225,7 +225,7 @@ def wrapper(self: SeriesOrIndex, *args: Any) -> SeriesOrIndex:
)

field = InternalField.from_struct_field(
self._internal.spark_frame.select(scol).schema[0],
self._internal.spark_frame.select(scol).schema[0], # type: ignore[arg-type]
use_extension_dtypes=any(
isinstance(col.dtype, extension_dtypes) for col in [self] + cols
),
@@ -252,7 +252,7 @@ def wrapper(self: SeriesOrIndex, *args: Any) -> SeriesOrIndex:
return wrapper


def numpy_column_op(f: Callable[..., Column]) -> Callable[..., SeriesOrIndex]:
def numpy_column_op(f: Callable[..., GenericColumn]) -> Callable[..., SeriesOrIndex]:
@wraps(f)
def wrapper(self: SeriesOrIndex, *args: Any) -> SeriesOrIndex:
# PySpark does not support NumPy type out of the box. For now, we convert NumPy types
@@ -287,7 +287,7 @@ def _psdf(self) -> DataFrame:

@abstractmethod
def _with_new_scol(
self: IndexOpsLike, scol: Column, *, field: Optional[InternalField] = None
self: IndexOpsLike, scol: GenericColumn, *, field: Optional[InternalField] = None
) -> IndexOpsLike:
pass

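With column_op widened to accept Callable[..., GenericColumn], the same decorator drives Series arithmetic under a classic session and under Spark Connect. A small usage sketch:

import pyspark.pandas as ps

psser = ps.Series([1, 2, 3])
# `psser + 1` goes through column_op(Column.__add__) internally; after this
# change the wrapped callable may return either flavor of Column.
print((psser + 1).tolist())  # [2, 3, 4]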
6 changes: 4 additions & 2 deletions python/pyspark/pandas/config.py
@@ -365,8 +365,9 @@ def get_option(key: str, default: Union[Any, _NoValueType] = _NoValue) -> Any:
if default is _NoValue:
default = _options_dict[key].default
_options_dict[key].validate(default)
spark_session = default_session()

return json.loads(default_session().conf.get(_key_format(key), default=json.dumps(default)))
return json.loads(spark_session.conf.get(_key_format(key), default=json.dumps(default)))


def set_option(key: str, value: Any) -> None:
@@ -386,8 +387,9 @@ def set_option(key: str, value: Any) -> None:
"""
_check_option(key)
_options_dict[key].validate(value)
spark_session = default_session()

default_session().conf.set(_key_format(key), json.dumps(value))
spark_session.conf.set(_key_format(key), json.dumps(value))


def reset_option(key: str) -> None:
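get_option and set_option round-trip values through the session's runtime conf as JSON strings, so the same code path works whether default_session() returns a classic or a Spark Connect session. A usage sketch with a real option name:

import pyspark.pandas as ps

ps.set_option("compute.ops_on_diff_frames", True)   # stored as the JSON string "true"
print(ps.get_option("compute.ops_on_diff_frames"))  # decoded back to True
ps.reset_option("compute.ops_on_diff_frames")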
16 changes: 11 additions & 5 deletions python/pyspark/pandas/data_type_ops/base.py
@@ -18,13 +18,13 @@
import numbers
from abc import ABCMeta
from itertools import chain
from typing import Any, Optional, Union
from typing import cast, Callable, Any, Optional, Union

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype

from pyspark.sql import functions as F, Column
from pyspark.sql import functions as F, Column as PySparkColumn
from pyspark.sql.types import (
ArrayType,
BinaryType,
@@ -44,7 +44,7 @@
TimestampNTZType,
UserDefinedType,
)
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex, GenericColumn
from pyspark.pandas.typedef import extension_dtypes
from pyspark.pandas.typedef.typehints import (
extension_dtypes_available,
@@ -53,6 +53,10 @@
spark_type_to_pandas_dtype,
)

# For supporting Spark Connect
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.utils import is_remote

if extension_dtypes_available:
from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype

@@ -470,14 +474,16 @@ def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
else:
from pyspark.pandas.base import column_op

return column_op(Column.__eq__)(left, right)
Column = ConnectColumn if is_remote() else PySparkColumn
return column_op(cast(Callable[..., GenericColumn], Column.__eq__))(left, right)

def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op

_sanitize_list_like(right)

return column_op(Column.__ne__)(left, right)
Column = ConnectColumn if is_remote() else PySparkColumn
return column_op(cast(Callable[..., GenericColumn], Column.__ne__))(left, right)

def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
raise TypeError("Unary ~ can not be applied to %s." % self.pretty_name)
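The eq/ne changes above show the dispatch idiom this PR leans on: pick the concrete Column class per call with is_remote(), then cast the bound operator so mypy accepts it as a Callable[..., GenericColumn]. A condensed sketch of the idiom, using only names that appear in the diff:

from typing import Callable, Union, cast

from pyspark.sql import Column as PySparkColumn
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.utils import is_remote

GenericColumn = Union[PySparkColumn, ConnectColumn]

def eq_operator() -> Callable[..., GenericColumn]:
    # Under Spark Connect, is_remote() is True and the Connect Column (which
    # builds protobuf expression trees) must be used; otherwise fall back to
    # the classic JVM-backed Column.
    Column = ConnectColumn if is_remote() else PySparkColumn
    return cast(Callable[..., GenericColumn], Column.__eq__)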
16 changes: 8 additions & 8 deletions python/pyspark/pandas/data_type_ops/binary_ops.py
@@ -15,12 +15,12 @@
# limitations under the License.
#

from typing import Any, Union, cast
from typing import Any, Union, cast, Callable

from pandas.api.types import CategoricalDtype

from pyspark.pandas.base import column_op, IndexOpsMixin
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex, GenericColumn
from pyspark.pandas.data_type_ops.base import (
DataTypeOps,
_as_categorical_type,
@@ -46,9 +46,9 @@ def add(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)

if isinstance(right, IndexOpsMixin) and isinstance(right.spark.data_type, BinaryType):
return column_op(F.concat)(left, right)
return column_op(cast(Callable[..., GenericColumn], F.concat))(left, right)
elif isinstance(right, bytes):
return column_op(F.concat)(left, F.lit(right))
return column_op(cast(Callable[..., GenericColumn], F.concat))(left, F.lit(right))
else:
raise TypeError(
"Concatenation can not be applied to %s and the given type." % self.pretty_name
@@ -71,26 +71,26 @@ def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:

_sanitize_list_like(right)

return column_op(Column.__lt__)(left, right)
return column_op(cast(Callable[..., GenericColumn], Column.__lt__))(left, right)

def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op

_sanitize_list_like(right)

return column_op(Column.__le__)(left, right)
return column_op(cast(Callable[..., GenericColumn], Column.__le__))(left, right)

def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op

_sanitize_list_like(right)
return column_op(Column.__ge__)(left, right)
return column_op(cast(Callable[..., GenericColumn], Column.__ge__))(left, right)

def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op

_sanitize_list_like(right)
return column_op(Column.__gt__)(left, right)
return column_op(cast(Callable[..., GenericColumn], Column.__gt__))(left, right)

def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
dtype, spark_type = pandas_on_spark_type(dtype)