Skip to content

Commit

Permalink
Fix dt_day_of_week/day_of_year, str_cat/extract/partition/replace/rpa…
Browse files Browse the repository at this point in the history
…rtition (modin-project#51)

* Fix dt_day_of_week/day_of_year, str_partition/replace/rpartition

* Fix str_extract
  • Loading branch information
helmeleegy authored Feb 2, 2023
1 parent 3a2e2c9 commit f7a31ab
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
2 changes: 2 additions & 0 deletions modin/core/execution/client/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,9 +659,11 @@ def forwarding_method(self, by, *args, **kwargs):
"dt_minute",
"dt_hour",
"dt_day",
"dt_day_of_week",
"dt_dayofweek",
"dt_weekday",
"dt_day_name",
"dt_day_of_year",
"dt_dayofyear",
"dt_week",
"dt_weekofyear",
Expand Down
53 changes: 44 additions & 9 deletions modin/pandas/series_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,14 @@ def casefold(self):
def cat(self, others=None, sep=None, na_rep=None, join=None):
if isinstance(others, Series):
others = others._to_pandas()
return self._default_to_pandas(
pandas.Series.str.cat, others=others, sep=sep, na_rep=na_rep, join=join
data = Series(query_compiler=self._query_compiler)
return data._reduce_dimension(
self._query_compiler.str_cat(
others=others,
sep=sep,
na_rep=na_rep,
join=join
)
)

def decode(self, encoding, errors="strict"):
Expand Down Expand Up @@ -307,6 +313,13 @@ def match(self, pat, case=True, flags=0, na=np.NaN):
)

def extract(self, pat, flags=0, expand=True):
import re
n = re.compile(pat).groups
if expand or n > 1:
from .dataframe import DataFrame
return DataFrame(
query_compiler=self._query_compiler.str_extract(pat, flags, expand)
)
return Series(
query_compiler=self._query_compiler.str_extract(pat, flags, expand)
)
Expand All @@ -329,9 +342,20 @@ def lstrip(self, to_strip=None):
def partition(self, sep=" ", expand=True):
if sep is not None and len(sep) == 0:
raise ValueError("empty separator")
return Series(
query_compiler=self._query_compiler.str_partition(sep=sep, expand=expand)
)

if expand:
from .dataframe import DataFrame
return DataFrame(
query_compiler=self._query_compiler.str_partition(
sep=sep, expand=expand
)
)
else:
return Series(
query_compiler=self._query_compiler.str_partition(
sep=sep, expand=expand
)
)

def removeprefix(self, prefix):
return Series(query_compiler=self._query_compiler.str_removeprefix(prefix))
Expand All @@ -343,11 +367,22 @@ def repeat(self, repeats):
return Series(query_compiler=self._query_compiler.str_repeat(repeats))

def rpartition(self, sep=" ", expand=True):
return Series(
query_compiler=self._query_compiler.str_rpartition(
sep=sep, expand=expand
if sep is not None and len(sep) == 0:
raise ValueError("empty separator")

if expand:
from .dataframe import DataFrame
return DataFrame(
query_compiler=self._query_compiler.str_rpartition(
sep=sep, expand=expand
)
)
else:
return Series(
query_compiler=self._query_compiler.str_rpartition(
sep=sep, expand=expand
)
)
)

def lower(self):
return Series(query_compiler=self._query_compiler.str_lower())
Expand Down

0 comments on commit f7a31ab

Please sign in to comment.