Commit

resolved comments

itholic committed Jul 6, 2021
1 parent a750b17 commit 0ce96fa
Showing 4 changed files with 37 additions and 28 deletions.
13 changes: 9 additions & 4 deletions python/pyspark/sql/session.py
@@ -436,7 +436,9 @@ def _inferSchemaFromList(self, data, names=None):
         """
         if not data:
             raise ValueError("can not infer schema from empty dataset")
-        schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
+        infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
+        schema = reduce(_merge_type, (_infer_schema(row, names, infer_dict_as_struct)
+                                      for row in data))
         if _has_nulltype(schema):
             raise ValueError("Some of types cannot be determined after inferring")
         return schema
@@ -462,11 +464,13 @@ def _inferSchema(self, rdd, samplingRatio=None, names=None):
             raise ValueError("The first row in RDD is empty, "
                              "can not infer schema")

+        infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()
         if samplingRatio is None:
-            schema = _infer_schema(first, names=names)
+            schema = _infer_schema(first, names=names, infer_dict_as_struct=infer_dict_as_struct)
             if _has_nulltype(schema):
                 for row in rdd.take(100)[1:]:
-                    schema = _merge_type(schema, _infer_schema(row, names=names))
+                    schema = _merge_type(schema, _infer_schema(
+                        row, names=names, infer_dict_as_struct=infer_dict_as_struct))
                     if not _has_nulltype(schema):
                         break
             else:
@@ -475,7 +479,8 @@
         else:
             if samplingRatio < 0.99:
                 rdd = rdd.sample(False, float(samplingRatio))
-            schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type)
+            schema = rdd.map(lambda row: _infer_schema(
+                row, names, infer_dict_as_struct=infer_dict_as_struct)).reduce(_merge_type)
         return schema

     def _createFromRDD(self, rdd, schema, samplingRatio):
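For context, the user-visible effect of threading infer_dict_as_struct through schema inference, as a minimal sketch that is not part of the commit and assumes a running SparkSession bound to `spark`:

    # With the new conf enabled, a nested dict in the input data is inferred
    # as a struct instead of a map (printSchema output may vary by version).
    from pyspark.sql import Row

    spark.conf.set("spark.sql.pyspark.inferNestedDictAsStruct.enabled", True)
    df = spark.createDataFrame([Row(outer={"payment": 200.5, "name": "A"})])
    df.printSchema()
    # root
    #  |-- outer: struct (nullable = true)
    #  |    |-- payment: double (nullable = true)
    #  |    |-- name: string (nullable = true)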
18 changes: 12 additions & 6 deletions python/pyspark/sql/tests/test_types.py
@@ -196,12 +196,6 @@ def test_infer_nested_schema(self):
         df = self.spark.createDataFrame(nestedRdd2)
         self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])

-        with self.sql_conf({"spark.sql.pyspark.inferNestedStructByMap": False}):
-            nestedRdd3 = self.sc.parallelize([NestedRow([{"payment": 200.5, "name": "A"}], [1, 2]),
-                                              NestedRow([{"payment": 100.5, "name": "B"}], [2, 3])])
-            df = self.spark.createDataFrame(nestedRdd3)
-            self.assertEqual(Row(f1=[Row(payment=200.5, name='A')], f2=[1, 2]), df.collect()[0])
-
         from collections import namedtuple
         CustomRow = namedtuple('CustomRow', 'field1 field2')
         rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
@@ -210,6 +204,18 @@
         df = self.spark.createDataFrame(rdd)
         self.assertEqual(Row(field1=1, field2=u'row1'), df.first())

+    def test_infer_nested_dict(self):
+        # SPARK-35929: Test inferring nested dict as a struct type.
+        NestedRow = Row("f1", "f2")
+
+        with self.sql_conf({"spark.sql.pyspark.inferNestedDictAsStruct.enabled": True}):
+            test = self.spark._wrapped._conf.inferDictAsStruct()
+            test1 = self.spark.conf.get("spark.sql.pyspark.inferNestedDictAsStruct.enabled")
+            nestedRdd = self.sc.parallelize([NestedRow([{"payment": 200.5, "name": "A"}], [1, 2]),
+                                             NestedRow([{"payment": 100.5, "name": "B"}], [2, 3])])
+            df = self.spark.createDataFrame(nestedRdd)
+            self.assertEqual(Row(f1=[Row(payment=200.5, name='A')], f2=[1, 2]), df.collect()[0])
+
     def test_create_dataframe_from_dict_respects_schema(self):
         df = self.spark.createDataFrame([{'a': 1}], ["b"])
         self.assertEqual(df.columns, ['b'])
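For contrast with the new test, a hedged sketch (not part of the suite) of the default path, where nested dicts are still inferred as maps; assumes a live `spark` session:

    # Default behavior (conf unset or false): the nested dict becomes a map,
    # so its values must share one type; reprs vary by Spark version.
    from pyspark.sql import Row

    df = spark.createDataFrame([Row(f1=[{"payment": 200.5}], f2=[1, 2])])
    print(df.schema["f1"].dataType)
    # ArrayType(MapType(StringType(), DoubleType(), True), True)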
25 changes: 12 additions & 13 deletions python/pyspark/sql/types.py
@@ -1003,7 +1003,7 @@ def _int_size_to_type(size):
 _array_type_mappings['u'] = StringType


-def _infer_type(obj):
+def _infer_type(obj, infer_dict_as_struct=False):
     """Infer the DataType from obj
     """
     if obj is None:
@@ -1020,23 +1020,22 @@ def _infer_type(obj):
         return dataType()

     if isinstance(obj, dict):
-        from pyspark.sql.session import SparkSession
-        if (SparkSession._activeSession.conf.get(
-                "spark.sql.pyspark.inferNestedStructByMap").lower() == "true"):
+        if infer_dict_as_struct:
+            struct = StructType()
             for key, value in obj.items():
                 if key is not None and value is not None:
-                    return MapType(_infer_type(key), _infer_type(value), True)
-            return MapType(NullType(), NullType(), True)
+                    struct.add(key, _infer_type(value, infer_dict_as_struct), True)
+            return struct
         else:
-            struct = StructType()
             for key, value in obj.items():
                 if key is not None and value is not None:
-                    struct.add(key, _infer_type(value), True)
-            return struct
+                    return MapType(_infer_type(key, infer_dict_as_struct),
+                                   _infer_type(value, infer_dict_as_struct), True)
+            return MapType(NullType(), NullType(), True)
     elif isinstance(obj, list):
         for v in obj:
             if v is not None:
-                return ArrayType(_infer_type(obj[0]), True)
+                return ArrayType(_infer_type(obj[0], infer_dict_as_struct), True)
         return ArrayType(NullType(), True)
     elif isinstance(obj, array):
         if obj.typecode in _array_type_mappings:
@@ -1045,12 +1044,12 @@ def _infer_type(obj):
             raise TypeError("not supported type: array(%s)" % obj.typecode)
     else:
         try:
-            return _infer_schema(obj)
+            return _infer_schema(obj, infer_dict_as_struct=infer_dict_as_struct)
         except TypeError:
             raise TypeError("not supported type: %s" % type(obj))


-def _infer_schema(row, names=None):
+def _infer_schema(row, names=None, infer_dict_as_struct=False):
     """Infer the schema from dict/namedtuple/object"""
     if isinstance(row, dict):
         items = sorted(row.items())
@@ -1076,7 +1075,7 @@ def _infer_schema(row, names=None):
     fields = []
     for k, v in items:
         try:
-            fields.append(StructField(k, _infer_type(v), True))
+            fields.append(StructField(k, _infer_type(v, infer_dict_as_struct), True))
         except TypeError as e:
             raise TypeError("Unable to infer the type of the field {}.".format(k)) from e
     return StructType(fields)
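A side effect of this change is that _infer_type no longer imports SparkSession to read the conf, so the helper can be exercised on its own. A hedged sketch against these private helpers (exact reprs vary by version):

    # Private API, for illustration only; behavior as of this change.
    from pyspark.sql.types import _infer_type

    _infer_type({"payment": 200.5}, infer_dict_as_struct=True)
    # -> a StructType with one field: StructField('payment', DoubleType(), True)

    _infer_type({"payment": 200.5})
    # -> MapType(StringType(), DoubleType(), True), the default map inference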
9 changes: 4 additions & 5 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3327,12 +3327,11 @@ object SQLConf {
     .intConf
     .createWithDefault(0)

-  val INFER_NESTED_STRUCT_BY_MAP = buildConf("spark.sql.pyspark.inferNestedStructByMap")
-    .internal()
-    .doc("When set to false, inferring the nested struct by StructType. MapType is default.")
+  val INFER_NESTED_DICT_AS_STRUCT = buildConf("spark.sql.pyspark.inferNestedDictAsStruct.enabled")
+    .doc("When set to true, infers the nested dict as a struct. By default, it infers it as a map")
     .version("3.2.0")
     .booleanConf
-    .createWithDefault(true)
+    .createWithDefault(false)

   /**
    * Holds information about keys that have been deprecated.
* Holds information about keys that have been deprecated.
Expand Down Expand Up @@ -4047,7 +4046,7 @@ class SQLConf extends Serializable with Logging {

def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS)

def inferNestedStructByMap: Boolean = getConf(SQLConf.INFER_NESTED_STRUCT_BY_MAP)
def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT)

/** ********************** SQLConf functionality methods ************ */

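Because the new conf defaults to false, map inference remains the out-of-the-box behavior and opting in is per session. A hedged sketch from the PySpark side (assumes a live `spark` session):

    # Runtime confs come back as strings; the default here is 'false'.
    spark.conf.get("spark.sql.pyspark.inferNestedDictAsStruct.enabled")  # 'false'
    spark.conf.set("spark.sql.pyspark.inferNestedDictAsStruct.enabled", True)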
