diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 08dce0fb95f1a..22c075ca36975 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -22,7 +22,7 @@ from typing import Optional from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer -from pyspark.sql.types import ArrayType, DataType, UserDefinedType +from pyspark.sql.types import ArrayType, DataType, UserDefinedType, StructType from pyspark.sql.pandas.types import to_arrow_type @@ -193,7 +193,7 @@ def create_array(s, t: pa.DataType, dt: Optional[DataType] = None): raise e return array - def create_arrs_names(s, t: pa.DataType, dt: Optional[DataType] = None): + def create_arrs_names(s, t: pa.DataType, dt: Optional[StructType] = None): # If input s is empty with zero columns, return empty Arrays with struct if len(s) == 0 and len(s.columns) == 0: return [(pa.array([], type=field.type), field.name) for field in t] @@ -240,6 +240,8 @@ def create_arrs_names(s, t: pa.DataType, dt: Optional[DataType] = None): raise ValueError("A field of type StructType expects a pandas.DataFrame, " "but got: %s" % str(type(s))) if isinstance(dt, DataType): + type_not_match = "dt must be instance of StructType when t is pyarrow struct" + assert isinstance(dt, StructType), type_not_match arrs_names = create_arrs_names(s, t, dt) else: arrs_names = create_arrs_names(s, t)