Skip to content

Commit

Permalink
to_replace support dict, value support single value, and add full tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-wang committed May 12, 2015
1 parent 9e232e7 commit 4a148f7
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
24 changes: 14 additions & 10 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,10 +933,13 @@ def replace(self, to_replace, value, subset=None):
:param to_replace: int, long, float, string, or list.
Value to be replaced.
The replacement value must be an int, long, float, or string.
If the value is a dict, then `value` is ignored and `to_replace` must be a
mapping from column name (string) to replacement value. The value to be
replaced must be an int, long, float, or string.
:param value: int, long, float, string, or list.
Value to use to replace holes.
The replacement value must be an int, long, float, or string.
The replacement value must be an int, long, float, or string. If `value` is a
list or tuple, `value` should be of the same length with `to_replace`.
:param subset: optional list of column names to consider.
Columns specified in subset that do not have matching data type are ignored.
For example, if `value` is a string, and subset contains a non-string column,
Expand All @@ -961,21 +964,18 @@ def replace(self, to_replace, value, subset=None):
|null| null|null|
+----+------+----+
"""
if not isinstance(to_replace, (float, int, long, basestring, list, tuple)):
raise ValueError("to_replace should be a float, int, long, string, list, or tuple")
if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)):
raise ValueError(
"to_replace should be a float, int, long, string, list, tuple, or dict")

if not isinstance(value, (float, int, long, basestring, list, tuple)):
raise ValueError("value should be a float, int, long, string, list, or tuple")

if isinstance(to_replace, dict) and not isinstance(value, (list, tuple, dict)):
raise TypeError("")
rep_dict = dict()

if isinstance(to_replace, (float, int, long, basestring)):
to_replace = [to_replace]

if isinstance(value, (float, int, long, basestring)):
value = [value]

if isinstance(to_replace, tuple):
to_replace = list(to_replace)

Expand All @@ -985,8 +985,12 @@ def replace(self, to_replace, value, subset=None):
if isinstance(to_replace, list) and isinstance(value, list):
if len(to_replace) != len(value):
raise ValueError("to_replace and value lists should be of the same length")
rep_dict = dict(zip(to_replace, value))
elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)):
rep_dict = {tr: value for tr in to_replace}
elif isinstance(to_replace, dict):
rep_dict = to_replace

rep_dict = dict(zip(to_replace, value))
if subset is None:
return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
elif isinstance(subset, basestring):
Expand Down
48 changes: 48 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,54 @@ def test_bitwise_operations(self):
result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
self.assertEqual(~75, result['~b'])

def test_replace(self):
schema = StructType([
StructField("name", StringType(), True),
StructField("age", IntegerType(), True),
StructField("height", DoubleType(), True)])

# replace with int
row = self.sqlCtx.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
self.assertEqual(row.age, 20)
self.assertEqual(row.height, 20.0)

# replace with double
row = self.sqlCtx.createDataFrame(
[(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
self.assertEqual(row.age, 82)
self.assertEqual(row.height, 82.1)

# replace with string
row = self.sqlCtx.createDataFrame(
[(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
self.assertEqual(row.name, u"Ann")
self.assertEqual(row.age, 10)

# replace with subset specified by a string of a column name w/ actual change
row = self.sqlCtx.createDataFrame(
[(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
self.assertEqual(row.age, 20)

# replace with subset specified by a string of a column name w/o actual change
row = self.sqlCtx.createDataFrame(
[(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
self.assertEqual(row.age, 10)

# replace with subset specified with one column replaced, another column not in subset
# stays unchanged.
row = self.sqlCtx.createDataFrame(
[(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
self.assertEqual(row.name, u'Alice')
self.assertEqual(row.age, 20)
self.assertEqual(row.height, 10.0)

# replace with subset specified but no column will be replaced
row = self.sqlCtx.createDataFrame(
[(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
self.assertEqual(row.name, u'Alice')
self.assertEqual(row.age, 10)
self.assertEqual(row.height, None)


class HiveContextSQLTests(ReusedPySparkTestCase):

Expand Down

0 comments on commit 4a148f7

Please sign in to comment.