[SPARK-35929][PYTHON] Support to infer nested dict as a struct when creating a DataFrame #33214

Closed
wants to merge 9 commits
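For context, a minimal usage sketch of what this PR enables (not part of the diff; it assumes a Spark build that includes this change, and the schema shown in comments is approximate):

from pyspark.sql import SparkSession, Row

spark = SparkSession.builder.getOrCreate()

# The new behavior is off by default; enable it for this session.
spark.conf.set("spark.sql.pyspark.inferNestedDictAsStruct.enabled", "true")

data = [Row(f1=[{"payment": 200.5, "name": "A"}], f2=[1, 2]),
        Row(f1=[{"payment": 100.5, "name": "B"}], f2=[2, 3])]

df = spark.createDataFrame(data)
df.printSchema()
# root
#  |-- f1: array (nullable = true)
#  |    |-- element: struct (containsNull = true)
#  |    |    |-- payment: double (nullable = true)
#  |    |    |-- name: string (nullable = true)
#  |-- f2: array (nullable = true)
#  |    |-- element: long (containsNull = true)

df.first()
# Row(f1=[Row(payment=200.5, name='A')], f2=[1, 2])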
13 changes: 9 additions & 4 deletions python/pyspark/sql/session.py
@@ -436,7 +436,9 @@ def _inferSchemaFromList(self, data, names=None):
         """
         if not data:
             raise ValueError("can not infer schema from empty dataset")
-        schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
+        infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
+        schema = reduce(_merge_type, (_infer_schema(row, names, infer_dict_as_struct)
+                                      for row in data))
         if _has_nulltype(schema):
             raise ValueError("Some of types cannot be determined after inferring")
         return schema
@@ -462,11 +464,13 @@ def _inferSchema(self, rdd, samplingRatio=None, names=None):
             raise ValueError("The first row in RDD is empty, "
                              "can not infer schema")
 
+        infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
         if samplingRatio is None:
-            schema = _infer_schema(first, names=names)
+            schema = _infer_schema(first, names=names, infer_dict_as_struct=infer_dict_as_struct)
             if _has_nulltype(schema):
                 for row in rdd.take(100)[1:]:
-                    schema = _merge_type(schema, _infer_schema(row, names=names))
+                    schema = _merge_type(schema, _infer_schema(
+                        row, names=names, infer_dict_as_struct=infer_dict_as_struct))
                     if not _has_nulltype(schema):
                         break
                 else:
@@ -475,7 +479,8 @@
         else:
             if samplingRatio < 0.99:
                 rdd = rdd.sample(False, float(samplingRatio))
-            schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type)
+            schema = rdd.map(lambda row: _infer_schema(
+                row, names, infer_dict_as_struct=infer_dict_as_struct)).reduce(_merge_type)
         return schema
 
     def _createFromRDD(self, rdd, schema, samplingRatio):
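Because the conf is looked up inside _inferSchemaFromList / _inferSchema on each call, toggling it at runtime changes how subsequent createDataFrame calls infer nested dicts. A small sketch under the same assumption as above (a build that includes this change); simpleString() output shown in comments:

from pyspark.sql import SparkSession, Row

spark = SparkSession.builder.getOrCreate()
rows = [Row(outer={"a": 1, "b": 2})]

spark.conf.set("spark.sql.pyspark.inferNestedDictAsStruct.enabled", "false")
print(spark.createDataFrame(rows).schema["outer"].dataType.simpleString())
# map<string,bigint>

spark.conf.set("spark.sql.pyspark.inferNestedDictAsStruct.enabled", "true")
print(spark.createDataFrame(rows).schema["outer"].dataType.simpleString())
# struct<a:bigint,b:bigint>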
15 changes: 15 additions & 0 deletions python/pyspark/sql/tests/test_types.py
@@ -204,6 +204,21 @@ def test_infer_nested_schema(self):
         df = self.spark.createDataFrame(rdd)
         self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
 
+    def test_infer_nested_dict_as_struct(self):
+        # SPARK-35929: Test inferring nested dict as a struct type.
+        NestedRow = Row("f1", "f2")
+
+        with self.sql_conf({"spark.sql.pyspark.inferNestedDictAsStruct.enabled": True}):
+            data = [NestedRow([{"payment": 200.5, "name": "A"}], [1, 2]),
+                    NestedRow([{"payment": 100.5, "name": "B"}], [2, 3])]
+
+            nestedRdd = self.sc.parallelize(data)
+            df = self.spark.createDataFrame(nestedRdd)
+            self.assertEqual(Row(f1=[Row(payment=200.5, name='A')], f2=[1, 2]), df.first())
+
+            df = self.spark.createDataFrame(data)
+            self.assertEqual(Row(f1=[Row(payment=200.5, name='A')], f2=[1, 2]), df.first())
+
     def test_create_dataframe_from_dict_respects_schema(self):
         df = self.spark.createDataFrame([{'a': 1}], ["b"])
         self.assertEqual(df.columns, ['b'])
26 changes: 17 additions & 9 deletions python/pyspark/sql/types.py
@@ -1003,7 +1003,7 @@ def _int_size_to_type(size):
 _array_type_mappings['u'] = StringType
 
 
-def _infer_type(obj):
+def _infer_type(obj, infer_dict_as_struct=False):
     """Infer the DataType from obj
     """
     if obj is None:
@@ -1020,14 +1020,22 @@ def _infer_type(obj):
         return dataType()
 
     if isinstance(obj, dict):
-        for key, value in obj.items():
-            if key is not None and value is not None:
-                return MapType(_infer_type(key), _infer_type(value), True)
-        return MapType(NullType(), NullType(), True)
+        if infer_dict_as_struct:
+            struct = StructType()
+            for key, value in obj.items():
+                if key is not None and value is not None:
+                    struct.add(key, _infer_type(value, infer_dict_as_struct), True)
+            return struct
+        else:
+            for key, value in obj.items():
+                if key is not None and value is not None:
+                    return MapType(_infer_type(key, infer_dict_as_struct),
+                                   _infer_type(value, infer_dict_as_struct), True)
+            return MapType(NullType(), NullType(), True)
Review comment (Member) on lines +1030 to +1034:
Do we need to log a warning if the inferred value types are inconsistent? We can recommend that users use the config.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment! :)
Actually, PySpark's type merging only handles the null case (as called out here) at

def _merge_type(a, b, name=None):
    if name is None:
        new_msg = lambda msg: msg
        new_name = lambda n: "field %s" % n
    else:
        new_msg = lambda msg: "%s: %s" % (name, msg)
        new_name = lambda n: "field %s in %s" % (n, name)

    if isinstance(a, NullType):
        return b
    elif isinstance(b, NullType):
        return a
    elif type(a) is not type(b):
        # TODO: type cast (such as int -> long)
        raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))

    # same type
    if isinstance(a, StructType):
        nfs = dict((f.name, f.dataType) for f in b.fields)
        fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()),
                                                  name=new_name(f.name)))
                  for f in a.fields]
        names = set([f.name for f in fields])
        for n in nfs:
            if n not in names:
                fields.append(StructField(n, nfs[n]))
        return StructType(fields)
    elif isinstance(a, ArrayType):
        return ArrayType(_merge_type(a.elementType, b.elementType,
                                     name='element in array %s' % name), True)
    elif isinstance(a, MapType):
        return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name),
                       _merge_type(a.valueType, b.valueType, name='value of map %s' % name),
                       True)
    else:
        return a

It actually fails for different types (unlike JSON or CSV type inference).
I am not sure what the ideal behavior is for the null case pointed out here, though.
Let me handle that separately from this PR in any event, if you're fine with that.
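To make the point above concrete, a small sketch using the private helpers from pyspark.sql.types (for exposition only): when the per-row inferred types disagree, _merge_type raises instead of widening.

from pyspark.sql.types import _infer_type, _merge_type

a = _infer_type({"payment": 200.5})   # map<string,double>
b = _infer_type({"payment": "n/a"})   # map<string,string>

_merge_type(a, b)
# TypeError: Can not merge type DoubleType and StringType (message approximate)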

     elif isinstance(obj, list):
         for v in obj:
             if v is not None:
-                return ArrayType(_infer_type(obj[0]), True)
+                return ArrayType(_infer_type(obj[0], infer_dict_as_struct), True)
         return ArrayType(NullType(), True)
     elif isinstance(obj, array):
         if obj.typecode in _array_type_mappings:
@@ -1036,12 +1044,12 @@ def _infer_type(obj):
             raise TypeError("not supported type: array(%s)" % obj.typecode)
     else:
         try:
-            return _infer_schema(obj)
+            return _infer_schema(obj, infer_dict_as_struct=infer_dict_as_struct)
         except TypeError:
             raise TypeError("not supported type: %s" % type(obj))
 
 
-def _infer_schema(row, names=None):
+def _infer_schema(row, names=None, infer_dict_as_struct=False):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
@@ -1067,7 +1075,7 @@ def _infer_schema(row, names=None):
     fields = []
     for k, v in items:
         try:
-            fields.append(StructField(k, _infer_type(v), True))
+            fields.append(StructField(k, _infer_type(v, infer_dict_as_struct), True))
         except TypeError as e:
             raise TypeError("Unable to infer the type of the field {}.".format(k)) from e
     return StructType(fields)
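The net effect of the types.py changes, sketched with a direct call to the private _infer_type helper (exposition only; simpleString() output shown in comments):

from pyspark.sql.types import _infer_type

d = {"payment": 200.5, "name": "A"}

print(_infer_type(d).simpleString())
# map<string,double>  (the first non-null key/value pair decides the map's key/value types)

print(_infer_type(d, infer_dict_as_struct=True).simpleString())
# struct<payment:double,name:string>  (one field per non-null key/value pair)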
@@ -3327,6 +3327,13 @@ object SQLConf {
     .intConf
     .createWithDefault(0)
 
+  val INFER_NESTED_DICT_AS_STRUCT = buildConf("spark.sql.pyspark.inferNestedDictAsStruct.enabled")
+    .doc("PySpark's SparkSession.createDataFrame infers the nested dict as a map by default. " +
+      "When it set to true, it infers the nested dict as a struct.")
+    .version("3.3.0")
+    .booleanConf
+    .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -4040,6 +4047,8 @@ class SQLConf extends Serializable with Logging {
 
   def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS)
 
+  def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
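Since the conf is registered with a default of false, existing applications keep map inference unless they opt in, for example per session (a sketch; the printed default is the registered value from the diff above):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

print(spark.conf.get("spark.sql.pyspark.inferNestedDictAsStruct.enabled"))
# 'false'

spark.conf.set("spark.sql.pyspark.inferNestedDictAsStruct.enabled", "true")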