Skip to content

Commit

Permalink
FIX-#2362: fix key handling in 'Series.__setitem__' (#2731)
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Chigarev <[email protected]>
  • Loading branch information
dchigarev authored Feb 15, 2021
1 parent 04cd912 commit a2ecf31
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
15 changes: 10 additions & 5 deletions modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,16 @@ def __round__(self, decimals=0):
)

def __setitem__(self, key, value):
if key not in self.keys():
raise KeyError(key)
self._create_or_update_from_compiler(
self._query_compiler.setitem(1, key, value), inplace=True
)
if isinstance(key, slice) and (
isinstance(key.start, int) or isinstance(key.stop, int)
):
# There could be two type of slices:
# - Location based slice (1:5)
# - Labels based slice ("a":"e")
# For location based slice we're going to `iloc`, since `loc` can't manage it.
self.iloc[key] = value
else:
self.loc[key] = value

def __sub__(self, right):
return self.sub(right)
Expand Down
26 changes: 21 additions & 5 deletions modin/pandas/test/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ def inter_df_math_helper_one_side(modin_series, pandas_series, op):
pass


def create_test_series(vals, sort=False):
def create_test_series(vals, sort=False, **kwargs):
if isinstance(vals, dict):
modin_series = pd.Series(vals[next(iter(vals.keys()))])
pandas_series = pandas.Series(vals[next(iter(vals.keys()))])
modin_series = pd.Series(vals[next(iter(vals.keys()))], **kwargs)
pandas_series = pandas.Series(vals[next(iter(vals.keys()))], **kwargs)
else:
modin_series = pd.Series(vals)
pandas_series = pandas.Series(vals)
modin_series = pd.Series(vals, **kwargs)
pandas_series = pandas.Series(vals, **kwargs)
if sort:
modin_series = modin_series.sort_values().reset_index(drop=True)
pandas_series = pandas_series.sort_values().reset_index(drop=True)
Expand Down Expand Up @@ -526,6 +526,22 @@ def test___setitem__(data):
df_equals(modin_series, pandas_series)


@pytest.mark.parametrize(
"key",
[
pytest.param(slice(1, 3), id="numeric_slice"),
pytest.param(slice("a", "c"), id="index_based_slice"),
pytest.param(["a", "c", "e"], id="list_of_labels"),
pytest.param([True, False, True, False, True], id="boolean_mask"),
],
)
def test___setitem___non_hashable(key):
md_sr, pd_sr = create_test_series([1, 2, 3, 4, 5], index=["a", "b", "c", "d", "e"])
md_sr[key] = 10
pd_sr[key] = 10
df_equals(md_sr, pd_sr)


@pytest.mark.parametrize("data", test_data_values, ids=test_data_keys)
def test___sizeof__(data):
modin_series, pandas_series = create_test_series(data)
Expand Down

0 comments on commit a2ecf31

Please sign in to comment.