diff --git a/ci/code_checks.sh b/ci/code_checks.sh index e13738b98833aa..45be38b40b6588 100755 --- a/ci/code_checks.sh +++ b/ci/code_checks.sh @@ -266,6 +266,10 @@ if [[ -z "$CHECK" || "$CHECK" == "doctests" ]]; then -k"-from_arrays -from_breaks -from_intervals -from_tuples -set_closed -to_tuples -interval_range" RET=$(($RET + $?)) ; echo $MSG "DONE" + MSG='Doctests arrays/string_.py' ; echo $MSG + pytest -q --doctest-modules pandas/core/arrays/string_.py + RET=$(($RET + $?)) ; echo $MSG "DONE" + fi ### DOCSTRINGS ### diff --git a/doc/source/getting_started/basics.rst b/doc/source/getting_started/basics.rst index 802ffadf2a81ef..36a7166f350e5b 100644 --- a/doc/source/getting_started/basics.rst +++ b/doc/source/getting_started/basics.rst @@ -986,7 +986,7 @@ not noted for a particular column will be ``NaN``: tsdf.agg({'A': ['mean', 'min'], 'B': 'sum'}) -.. _basics.aggregation.mixed_dtypes: +.. _basics.aggregation.mixed_string: Mixed dtypes ++++++++++++ @@ -1704,7 +1704,8 @@ built-in string methods. For example: .. ipython:: python - s = pd.Series(['A', 'B', 'C', 'Aaba', 'Baca', np.nan, 'CABA', 'dog', 'cat']) + s = pd.Series(['A', 'B', 'C', 'Aaba', 'Baca', np.nan, 'CABA', 'dog', 'cat'], + dtype="string") s.str.lower() Powerful pattern-matching methods are provided as well, but note that @@ -1712,6 +1713,12 @@ pattern-matching generally uses `regular expressions `__ by default (and in some cases always uses them). +.. note:: + + Prior to pandas 1.0, string methods were only available on ``object`` -dtype + ``Series``. Pandas 1.0 added the :class:`StringDtype` which is dedicated + to strings. See :ref:`text.types` for more. + Please see :ref:`Vectorized String Methods ` for a complete description. @@ -1925,9 +1932,15 @@ period (time spans) :class:`PeriodDtype` :class:`Period` :class:`arrays. 
sparse :class:`SparseDtype` (none) :class:`arrays.SparseArray` :ref:`sparse` intervals :class:`IntervalDtype` :class:`Interval` :class:`arrays.IntervalArray` :ref:`advanced.intervalindex` nullable integer :class:`Int64Dtype`, ... (none) :class:`arrays.IntegerArray` :ref:`integer_na` +Strings :class:`StringDtype` :class:`str` :class:`arrays.StringArray` :ref:`text` =================== ========================= ================== ============================= ============================= -Pandas uses the ``object`` dtype for storing strings. +Pandas has two ways to store strings. + +1. ``object`` dtype, which can hold any Python object, including strings. +2. :class:`StringDtype`, which is dedicated to strings. + +Generally, we recommend using :class:`StringDtype`. See :ref:`text.types` for more. Finally, arbitrary objects may be stored using the ``object`` dtype, but should be avoided to the extent possible (for performance and interoperability with diff --git a/doc/source/reference/arrays.rst b/doc/source/reference/arrays.rst index 7f464bf952bfbf..0c435e06ac57f6 100644 --- a/doc/source/reference/arrays.rst +++ b/doc/source/reference/arrays.rst @@ -24,6 +24,7 @@ Intervals :class:`IntervalDtype` :class:`Interval` :ref:`api.array Nullable Integer :class:`Int64Dtype`, ... (none) :ref:`api.arrays.integer_na` Categorical :class:`CategoricalDtype` (none) :ref:`api.arrays.categorical` Sparse :class:`SparseDtype` (none) :ref:`api.arrays.sparse` +Strings :class:`StringDtype` :class:`str` :ref:`api.arrays.string` =================== ========================= ================== ============================= Pandas and third-party libraries can extend NumPy's type system (see :ref:`extending.extension-types`). @@ -460,6 +461,29 @@ and methods if the :class:`Series` contains sparse values. See :ref:`api.series.sparse` for more. +.. 
_api.arrays.string: + +Text data +--------- + +When working with text data, where each valid element is a string or missing, +we recommend using :class:`StringDtype` (with the alias ``"string"``). + +.. autosummary:: + :toctree: api/ + :template: autosummary/class_without_autosummary.rst + + arrays.StringArray + +.. autosummary:: + :toctree: api/ + :template: autosummary/class_without_autosummary.rst + + StringDtype + +The ``Series.str`` accessor is available for ``Series`` backed by a :class:`arrays.StringArray`. +See :ref:`api.series.str` for more. + .. Dtype attributes which are manually listed in their docstrings: including .. it here to make sure a docstring page is built for them @@ -471,4 +495,4 @@ and methods if the :class:`Series` contains sparse values. See DatetimeTZDtype.unit DatetimeTZDtype.tz PeriodDtype.freq - IntervalDtype.subtype \ No newline at end of file + IntervalDtype.subtype diff --git a/doc/source/user_guide/text.rst b/doc/source/user_guide/text.rst index acb5810e5252a9..789ff2a65355b5 100644 --- a/doc/source/user_guide/text.rst +++ b/doc/source/user_guide/text.rst @@ -6,8 +6,71 @@ Working with text data ====================== +.. _text.types: + +Text Data Types +--------------- + +.. versionadded:: 1.0.0 + +There are two main ways to store text data + +1. ``object`` -dtype NumPy array. +2. :class:`StringDtype` extension type. + +We recommend using :class:`StringDtype` to store text data. + +Prior to pandas 1.0, ``object`` dtype was the only option. This was unfortunate +for many reasons: + +1. You can accidentally store a *mixture* of strings and non-strings in an + ``object`` dtype array. It's better to have a dedicated dtype. +2. ``object`` dtype breaks dtype-specific operations like :meth:`DataFrame.select_dtypes`. + There isn't a clear way to select *just* text while excluding non-text + but still object-dtype columns. +3. When reading code, the contents of an ``object`` dtype array is less clear + than ``'string'``. 
+ +Currently, the performance of ``object`` dtype arrays of strings and +:class:`arrays.StringArray` are about the same. We expect future enhancements +to significantly increase the performance and lower the memory overhead of +:class:`~arrays.StringArray`. + +.. warning:: + + ``StringArray`` is currently considered experimental. The implementation + and parts of the API may change without warning. + +For backwards-compatibility, ``object`` dtype remains the default type we +infer a list of strings to + +.. ipython:: python + + pd.Series(['a', 'b', 'c']) + +To explicitly request ``string`` dtype, specify the ``dtype`` + +.. ipython:: python + + pd.Series(['a', 'b', 'c'], dtype="string") + pd.Series(['a', 'b', 'c'], dtype=pd.StringDtype()) + +Or ``astype`` after the ``Series`` or ``DataFrame`` is created + +.. ipython:: python + + s = pd.Series(['a', 'b', 'c']) + s + s.astype("string") + +Everything that follows in the rest of this document applies equally to +``string`` and ``object`` dtype. + .. _text.string_methods: +String Methods +-------------- + Series and Index are equipped with a set of string processing methods that make it easy to operate on each element of the array. Perhaps most importantly, these methods exclude missing/NA values automatically. These are @@ -16,7 +79,8 @@ the equivalent (scalar) built-in string methods: .. ipython:: python - s = pd.Series(['A', 'B', 'C', 'Aaba', 'Baca', np.nan, 'CABA', 'dog', 'cat']) + s = pd.Series(['A', 'B', 'C', 'Aaba', 'Baca', np.nan, 'CABA', 'dog', 'cat'], + dtype="string") s.str.lower() s.str.upper() s.str.len() @@ -90,7 +154,7 @@ Methods like ``split`` return a Series of lists: .. ipython:: python - s2 = pd.Series(['a_b_c', 'c_d_e', np.nan, 'f_g_h']) + s2 = pd.Series(['a_b_c', 'c_d_e', np.nan, 'f_g_h'], dtype="string") s2.str.split('_') Elements in the split lists can be accessed using ``get`` or ``[]`` notation: @@ -106,6 +170,9 @@ It is easy to expand this to return a DataFrame using ``expand``. 
s2.str.split('_', expand=True) +When original ``Series`` has :class:`StringDtype`, the output columns will all +be :class:`StringDtype` as well. + It is also possible to limit the number of splits: .. ipython:: python @@ -125,7 +192,8 @@ i.e., from the end of the string to the beginning of the string: .. ipython:: python s3 = pd.Series(['A', 'B', 'C', 'Aaba', 'Baca', - '', np.nan, 'CABA', 'dog', 'cat']) + '', np.nan, 'CABA', 'dog', 'cat'], + dtype="string") s3 s3.str.replace('^.a|dog', 'XX-XX ', case=False) @@ -136,7 +204,7 @@ following code will cause trouble because of the regular expression meaning of .. ipython:: python # Consider the following badly formatted financial data - dollars = pd.Series(['12', '-$10', '$10,000']) + dollars = pd.Series(['12', '-$10', '$10,000'], dtype="string") # This does what you'd naively expect: dollars.str.replace('$', '') @@ -174,7 +242,8 @@ positional argument (a regex object) and return a string. def repl(m): return m.group(0)[::-1] - pd.Series(['foo 123', 'bar baz', np.nan]).str.replace(pat, repl) + pd.Series(['foo 123', 'bar baz', np.nan], + dtype="string").str.replace(pat, repl) # Using regex groups pat = r"(?P\w+) (?P\w+) (?P\w+)" @@ -182,7 +251,8 @@ positional argument (a regex object) and return a string. def repl(m): return m.group('two').swapcase() - pd.Series(['Foo Bar Baz', np.nan]).str.replace(pat, repl) + pd.Series(['Foo Bar Baz', np.nan], + dtype="string").str.replace(pat, repl) .. versionadded:: 0.20.0 @@ -221,7 +291,7 @@ The content of a ``Series`` (or ``Index``) can be concatenated: .. ipython:: python - s = pd.Series(['a', 'b', 'c', 'd']) + s = pd.Series(['a', 'b', 'c', 'd'], dtype="string") s.str.cat(sep=',') If not specified, the keyword ``sep`` for the separator defaults to the empty string, ``sep=''``: @@ -234,7 +304,7 @@ By default, missing values are ignored. Using ``na_rep``, they can be given a re .. 
ipython:: python - t = pd.Series(['a', 'b', np.nan, 'd']) + t = pd.Series(['a', 'b', np.nan, 'd'], dtype="string") t.str.cat(sep=',') t.str.cat(sep=',', na_rep='-') @@ -279,7 +349,8 @@ the ``join``-keyword. .. ipython:: python :okwarning: - u = pd.Series(['b', 'd', 'a', 'c'], index=[1, 3, 0, 2]) + u = pd.Series(['b', 'd', 'a', 'c'], index=[1, 3, 0, 2], + dtype="string") s u s.str.cat(u) @@ -295,7 +366,8 @@ In particular, alignment also means that the different lengths do not need to co .. ipython:: python - v = pd.Series(['z', 'a', 'b', 'd', 'e'], index=[-1, 0, 1, 3, 4]) + v = pd.Series(['z', 'a', 'b', 'd', 'e'], index=[-1, 0, 1, 3, 4], + dtype="string") s v s.str.cat(v, join='left', na_rep='-') @@ -351,7 +423,8 @@ of the string, the result will be a ``NaN``. .. ipython:: python s = pd.Series(['A', 'B', 'C', 'Aaba', 'Baca', np.nan, - 'CABA', 'dog', 'cat']) + 'CABA', 'dog', 'cat'], + dtype="string") s.str[0] s.str[1] @@ -382,7 +455,8 @@ DataFrame with one column per group. .. ipython:: python - pd.Series(['a1', 'b2', 'c3']).str.extract(r'([ab])(\d)', expand=False) + pd.Series(['a1', 'b2', 'c3'], + dtype="string").str.extract(r'([ab])(\d)', expand=False) Elements that do not match return a row filled with ``NaN``. Thus, a Series of messy strings can be "converted" into a like-indexed Series @@ -395,14 +469,16 @@ Named groups like .. ipython:: python - pd.Series(['a1', 'b2', 'c3']).str.extract(r'(?P[ab])(?P\d)', - expand=False) + pd.Series(['a1', 'b2', 'c3'], + dtype="string").str.extract(r'(?P[ab])(?P\d)', + expand=False) and optional groups like .. ipython:: python - pd.Series(['a1', 'b2', '3']).str.extract(r'([ab])?(\d)', expand=False) + pd.Series(['a1', 'b2', '3'], + dtype="string").str.extract(r'([ab])?(\d)', expand=False) can also be used. Note that any capture group names in the regular expression will be used for column names; otherwise capture group @@ -413,20 +489,23 @@ with one column if ``expand=True``. .. 
ipython:: python - pd.Series(['a1', 'b2', 'c3']).str.extract(r'[ab](\d)', expand=True) + pd.Series(['a1', 'b2', 'c3'], + dtype="string").str.extract(r'[ab](\d)', expand=True) It returns a Series if ``expand=False``. .. ipython:: python - pd.Series(['a1', 'b2', 'c3']).str.extract(r'[ab](\d)', expand=False) + pd.Series(['a1', 'b2', 'c3'], + dtype="string").str.extract(r'[ab](\d)', expand=False) Calling on an ``Index`` with a regex with exactly one capture group returns a ``DataFrame`` with one column if ``expand=True``. .. ipython:: python - s = pd.Series(["a1", "b2", "c3"], ["A11", "B22", "C33"]) + s = pd.Series(["a1", "b2", "c3"], ["A11", "B22", "C33"], + dtype="string") s s.index.str.extract("(?P[a-zA-Z])", expand=True) @@ -471,7 +550,8 @@ Unlike ``extract`` (which returns only the first match), .. ipython:: python - s = pd.Series(["a1a2", "b1", "c1"], index=["A", "B", "C"]) + s = pd.Series(["a1a2", "b1", "c1"], index=["A", "B", "C"], + dtype="string") s two_groups = '(?P[a-z])(?P[0-9])' s.str.extract(two_groups, expand=True) @@ -489,7 +569,7 @@ When each subject string in the Series has exactly one match, .. ipython:: python - s = pd.Series(['a3', 'b3', 'c2']) + s = pd.Series(['a3', 'b3', 'c2'], dtype="string") s then ``extractall(pat).xs(0, level='match')`` gives the same result as @@ -510,7 +590,7 @@ same result as a ``Series.str.extractall`` with a default index (starts from 0). pd.Index(["a1a2", "b1", "c1"]).str.extractall(two_groups) - pd.Series(["a1a2", "b1", "c1"]).str.extractall(two_groups) + pd.Series(["a1a2", "b1", "c1"], dtype="string").str.extractall(two_groups) Testing for Strings that match or contain a pattern @@ -521,13 +601,15 @@ You can check whether elements contain a pattern: .. ipython:: python pattern = r'[0-9][a-z]' - pd.Series(['1', '2', '3a', '3b', '03c']).str.contains(pattern) + pd.Series(['1', '2', '3a', '3b', '03c'], + dtype="string").str.contains(pattern) Or whether elements match a pattern: .. 
ipython:: python - pd.Series(['1', '2', '3a', '3b', '03c']).str.match(pattern) + pd.Series(['1', '2', '3a', '3b', '03c'], + dtype="string").str.match(pattern) The distinction between ``match`` and ``contains`` is strictness: ``match`` relies on strict ``re.match``, while ``contains`` relies on ``re.search``. @@ -537,7 +619,8 @@ an extra ``na`` argument so missing values can be considered True or False: .. ipython:: python - s4 = pd.Series(['A', 'B', 'C', 'Aaba', 'Baca', np.nan, 'CABA', 'dog', 'cat']) + s4 = pd.Series(['A', 'B', 'C', 'Aaba', 'Baca', np.nan, 'CABA', 'dog', 'cat'], + dtype="string") s4.str.contains('A', na=False) .. _text.indicator: @@ -550,7 +633,7 @@ For example if they are separated by a ``'|'``: .. ipython:: python - s = pd.Series(['a', 'a|b', np.nan, 'a|c']) + s = pd.Series(['a', 'a|b', np.nan, 'a|c'], dtype="string") s.str.get_dummies(sep='|') String ``Index`` also supports ``get_dummies`` which returns a ``MultiIndex``. diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index ceb1247e176dcb..d8dc8ae68c347c 100644 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -50,14 +50,56 @@ including other versions of pandas. Enhancements ~~~~~~~~~~~~ -- :meth:`DataFrame.to_string` added the ``max_colwidth`` parameter to control when wide columns are truncated (:issue:`9784`) -- +.. _whatsnew_100.string: + +Dedicated string data type +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We've added :class:`StringDtype`, an extension type dedicated to string data. +Previously, strings were typically stored in object-dtype NumPy arrays. + +.. warning:: + + ``StringDtype`` is currently considered experimental. The implementation + and parts of the API may change without warning. + +The text extension type solves several issues with object-dtype NumPy arrays: + +1. You can accidentally store a *mixture* of strings and non-strings in an + ``object`` dtype array. A ``StringArray`` can only store strings. +2. 
``object`` dtype breaks dtype-specific operations like :meth:`DataFrame.select_dtypes`. + There isn't a clear way to select *just* text while excluding non-text, + but still object-dtype columns. +3. When reading code, the contents of an ``object`` dtype array is less clear + than ``string``. + + +.. ipython:: python + + pd.Series(['abc', None, 'def'], dtype=pd.StringDtype()) + +You can use the alias ``"string"`` as well. + +.. ipython:: python + + s = pd.Series(['abc', None, 'def'], dtype="string") + s + +The usual string accessor methods work. Where appropriate, the return type +of the Series or columns of a DataFrame will also have string dtype. + + s.str.upper() + s.str.split('b', expand=True).dtypes + +We recommend explicitly using the ``string`` data type when working with strings. +See :ref:`text.types` for more. .. _whatsnew_1000.enhancements.other: Other enhancements ^^^^^^^^^^^^^^^^^^ +- :meth:`DataFrame.to_string` added the ``max_colwidth`` parameter to control when wide columns are truncated (:issue:`9784`) - :meth:`MultiIndex.from_product` infers level names from inputs if not explicitly provided (:issue:`27292`) - :meth:`DataFrame.to_latex` now accepts ``caption`` and ``label`` arguments (:issue:`25436`) - The :ref:`integer dtype ` with support for missing values can now be converted to diff --git a/pandas/__init__.py b/pandas/__init__.py index 6d0c55a45ed46c..5d163e411c0acb 100644 --- a/pandas/__init__.py +++ b/pandas/__init__.py @@ -66,6 +66,7 @@ PeriodDtype, IntervalDtype, DatetimeTZDtype, + StringDtype, # missing isna, isnull, diff --git a/pandas/arrays/__init__.py b/pandas/arrays/__init__.py index db01f2a0c674f6..9870b5bed076d8 100644 --- a/pandas/arrays/__init__.py +++ b/pandas/arrays/__init__.py @@ -11,6 +11,7 @@ PandasArray, PeriodArray, SparseArray, + StringArray, TimedeltaArray, ) @@ -22,5 +23,6 @@ "PandasArray", "PeriodArray", "SparseArray", + "StringArray", "TimedeltaArray", ] diff --git a/pandas/core/api.py b/pandas/core/api.py index 
bd2a57a15bdd2b..04f2f84c92a157 100644 --- a/pandas/core/api.py +++ b/pandas/core/api.py @@ -10,6 +10,7 @@ ) from pandas.core.dtypes.missing import isna, isnull, notna, notnull +# TODO: Remove get_dummies import when statsmodels updates #18264 from pandas.core.algorithms import factorize, unique, value_counts from pandas.core.arrays import Categorical from pandas.core.arrays.integer import ( @@ -22,12 +23,9 @@ UInt32Dtype, UInt64Dtype, ) +from pandas.core.arrays.string_ import StringDtype from pandas.core.construction import array - from pandas.core.groupby import Grouper, NamedAgg - -# DataFrame needs to be imported after NamedAgg to avoid a circular import -from pandas.core.frame import DataFrame # isort:skip from pandas.core.index import ( CategoricalIndex, DatetimeIndex, @@ -47,9 +45,7 @@ from pandas.core.indexes.period import Period, period_range from pandas.core.indexes.timedeltas import Timedelta, timedelta_range from pandas.core.indexing import IndexSlice -from pandas.core.reshape.reshape import ( - get_dummies, -) # TODO: Remove get_dummies import when statsmodels updates #18264 +from pandas.core.reshape.reshape import get_dummies from pandas.core.series import Series from pandas.core.tools.datetimes import to_datetime from pandas.core.tools.numeric import to_numeric @@ -57,3 +53,6 @@ from pandas.io.formats.format import set_eng_float_format from pandas.tseries.offsets import DateOffset + +# DataFrame needs to be imported after NamedAgg to avoid a circular import +from pandas.core.frame import DataFrame # isort:skip diff --git a/pandas/core/arrays/__init__.py b/pandas/core/arrays/__init__.py index 5c83ed8cf5e241..868118bac6a7b8 100644 --- a/pandas/core/arrays/__init__.py +++ b/pandas/core/arrays/__init__.py @@ -10,4 +10,5 @@ from .numpy_ import PandasArray, PandasDtype # noqa: F401 from .period import PeriodArray, period_array # noqa: F401 from .sparse import SparseArray # noqa: F401 +from .string_ import StringArray # noqa: F401 from .timedeltas import 
TimedeltaArray # noqa: F401 diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py index 32da0199e28f88..bf7404e8997c6b 100644 --- a/pandas/core/arrays/numpy_.py +++ b/pandas/core/arrays/numpy_.py @@ -10,7 +10,7 @@ from pandas.core.dtypes.dtypes import ExtensionDtype from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries -from pandas.core.dtypes.inference import is_array_like, is_list_like +from pandas.core.dtypes.inference import is_array_like from pandas.core.dtypes.missing import isna from pandas import compat @@ -229,13 +229,15 @@ def __getitem__(self, item): def __setitem__(self, key, value): value = extract_array(value, extract_numpy=True) - if not lib.is_scalar(key) and is_list_like(key): + scalar_key = lib.is_scalar(key) + scalar_value = lib.is_scalar(value) + + if not scalar_key and scalar_value: key = np.asarray(key) - if not lib.is_scalar(value): - value = np.asarray(value) + if not scalar_value: + value = np.asarray(value, dtype=self._ndarray.dtype) - value = np.asarray(value, dtype=self._ndarray.dtype) self._ndarray[key] = value def __len__(self) -> int: diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py new file mode 100644 index 00000000000000..87649ac6511274 --- /dev/null +++ b/pandas/core/arrays/string_.py @@ -0,0 +1,281 @@ +import operator +from typing import TYPE_CHECKING, Type + +import numpy as np + +from pandas._libs import lib + +from pandas.core.dtypes.base import ExtensionDtype +from pandas.core.dtypes.common import pandas_dtype +from pandas.core.dtypes.dtypes import register_extension_dtype +from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries +from pandas.core.dtypes.inference import is_array_like + +from pandas import compat +from pandas.core import ops +from pandas.core.arrays import PandasArray +from pandas.core.construction import extract_array +from pandas.core.missing import isna + +if TYPE_CHECKING: + from pandas._typing import Scalar + + 
+@register_extension_dtype +class StringDtype(ExtensionDtype): + """ + Extension dtype for string data. + + .. versionadded:: 1.0.0 + + .. warning:: + + StringDtype is considered experimental. The implementation and + parts of the API may change without warning. + + In particular, StringDtype.na_value may change to no longer be + ``numpy.nan``. + + Attributes + ---------- + None + + Methods + ------- + None + + Examples + -------- + >>> pd.StringDtype() + StringDtype + """ + + @property + def na_value(self) -> "Scalar": + """ + StringDtype uses :attr:`numpy.nan` as the missing NA value. + + .. warning:: + + `na_value` may change in a future release. + """ + return np.nan + + @property + def type(self) -> Type: + return str + + @property + def name(self) -> str: + """ + The alias for StringDtype is ``'string'``. + """ + return "string" + + @classmethod + def construct_from_string(cls, string: str) -> ExtensionDtype: + if string == "string": + return cls() + return super().construct_from_string(string) + + @classmethod + def construct_array_type(cls) -> "Type[StringArray]": + return StringArray + + def __repr__(self) -> str: + return "StringDtype" + + +class StringArray(PandasArray): + """ + Extension array for string data. + + .. versionadded:: 1.0.0 + + .. warning:: + + StringArray is considered experimental. The implementation and + parts of the API may change without warning. + + In particular, the NA value used may change to no longer be + ``numpy.nan``. + + Parameters + ---------- + values : array-like + The array of data. + + .. warning:: + + Currently, this expects an object-dtype ndarray + where the elements are Python strings. This may + change without warning in the future. + copy : bool, default False + Whether to copy the array of data. + + Attributes + ---------- + None + + Methods + ------- + None + + See Also + -------- + Series.str + The string methods are available on Series backed by + a StringArray. 
+ + Examples + -------- + >>> pd.array(['This is', 'some text', None, 'data.'], dtype="string") + + ['This is', 'some text', nan, 'data.'] + Length: 4, dtype: string + + Unlike ``object`` dtype arrays, ``StringArray`` doesn't allow non-string + values. + + >>> pd.array(['1', 1], dtype="string") + Traceback (most recent call last): + ... + ValueError: StringArray requires an object-dtype ndarray of strings. + """ + + # undo the PandasArray hack + _typ = "extension" + + def __init__(self, values, copy=False): + values = extract_array(values) + skip_validation = isinstance(values, type(self)) + + super().__init__(values, copy=copy) + self._dtype = StringDtype() + if not skip_validation: + self._validate() + + def _validate(self): + """Validate that we only store NA or strings.""" + if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True): + raise ValueError( + "StringArray requires a sequence of strings or missing values." + ) + if self._ndarray.dtype != "object": + raise ValueError( + "StringArray requires a sequence of strings. Got " + "'{}' dtype instead.".format(self._ndarray.dtype) + ) + + @classmethod + def _from_sequence(cls, scalars, dtype=None, copy=False): + if dtype: + assert dtype == "string" + result = super()._from_sequence(scalars, dtype=object, copy=copy) + # convert None to np.nan + # TODO: it would be nice to do this in _validate / lib.is_string_array + # We are already doing a scan over the values there. 
+ result[result.isna()] = np.nan + return result + + @classmethod + def _from_sequence_of_strings(cls, strings, dtype=None, copy=False): + return cls._from_sequence(strings, dtype=dtype, copy=copy) + + def __setitem__(self, key, value): + value = extract_array(value, extract_numpy=True) + if isinstance(value, type(self)): + # extract_array doesn't extract PandasArray subclasses + value = value._ndarray + + scalar_key = lib.is_scalar(key) + scalar_value = lib.is_scalar(value) + if scalar_key and not scalar_value: + raise ValueError("setting an array element with a sequence.") + + # validate new items + if scalar_value: + if scalar_value is None: + value = np.nan + elif not (isinstance(value, str) or np.isnan(value)): + raise ValueError( + "Cannot set non-string value '{}' into a StringArray.".format(value) + ) + else: + if not is_array_like(value): + value = np.asarray(value, dtype=object) + if len(value) and not lib.is_string_array(value, skipna=True): + raise ValueError("Must provide strings.") + + super().__setitem__(key, value) + + def fillna(self, value=None, method=None, limit=None): + # TODO: validate dtype + return super().fillna(value, method, limit) + + def astype(self, dtype, copy=True): + dtype = pandas_dtype(dtype) + if isinstance(dtype, StringDtype): + if copy: + return self.copy() + return self + return super().astype(dtype, copy) + + def _reduce(self, name, skipna=True, **kwargs): + raise TypeError("Cannot perform reduction '{}' with string dtype".format(name)) + + def value_counts(self, dropna=False): + from pandas import value_counts + + return value_counts(self._ndarray, dropna=dropna) + + # Overrride parent because we have different return types. 
+ @classmethod + def _create_arithmetic_method(cls, op): + def method(self, other): + if isinstance(other, (ABCIndexClass, ABCSeries, ABCDataFrame)): + return NotImplemented + + elif isinstance(other, cls): + other = other._ndarray + + mask = isna(self) | isna(other) + valid = ~mask + + if not lib.is_scalar(other): + if len(other) != len(self): + # prevent improper broadcasting when other is 2D + raise ValueError( + "Lengths of operands do not match: {} != {}".format( + len(self), len(other) + ) + ) + + other = np.asarray(other) + other = other[valid] + + result = np.empty_like(self._ndarray, dtype="object") + result[mask] = np.nan + result[valid] = op(self._ndarray[valid], other) + + if op.__name__ in {"add", "radd", "mul", "rmul"}: + return StringArray(result) + else: + dtype = "object" if mask.any() else "bool" + return np.asarray(result, dtype=dtype) + + return compat.set_function_name(method, "__{}__".format(op.__name__), cls) + + @classmethod + def _add_arithmetic_ops(cls): + cls.__add__ = cls._create_arithmetic_method(operator.add) + cls.__radd__ = cls._create_arithmetic_method(ops.radd) + + cls.__mul__ = cls._create_arithmetic_method(operator.mul) + cls.__rmul__ = cls._create_arithmetic_method(ops.rmul) + + _create_comparison_method = _create_arithmetic_method + + +StringArray._add_arithmetic_ops() +StringArray._add_comparison_ops() diff --git a/pandas/core/dtypes/missing.py b/pandas/core/dtypes/missing.py index cd87fbef02e4fb..56bfbefdbf248d 100644 --- a/pandas/core/dtypes/missing.py +++ b/pandas/core/dtypes/missing.py @@ -128,6 +128,7 @@ def isna(obj): def _isna_new(obj): + if is_scalar(obj): return libmissing.checknull(obj) # hack (for now) because MI registers as ndarray diff --git a/pandas/core/strings.py b/pandas/core/strings.py index 25350119f9df50..888d2ae6f94737 100644 --- a/pandas/core/strings.py +++ b/pandas/core/strings.py @@ -763,6 +763,16 @@ def f(x): return f +def _result_dtype(arr): + # workaround #27953 + # ideally we just pass 
`dtype=arr.dtype` unconditionally, but this fails + # when the list of values is empty. + if arr.dtype.name == "string": + return "string" + else: + return object + + def _str_extract_noexpand(arr, pat, flags=0): """ Find groups in each string in the Series using passed regular @@ -817,11 +827,12 @@ def _str_extract_frame(arr, pat, flags=0): result_index = arr.index except AttributeError: result_index = None + dtype = _result_dtype(arr) return DataFrame( [groups_or_na(val) for val in arr], columns=columns, index=result_index, - dtype=object, + dtype=dtype, ) @@ -1019,8 +1030,11 @@ def str_extractall(arr, pat, flags=0): from pandas import MultiIndex index = MultiIndex.from_tuples(index_list, names=arr.index.names + ["match"]) + dtype = _result_dtype(arr) - result = arr._constructor_expanddim(match_list, index=index, columns=columns) + result = arr._constructor_expanddim( + match_list, index=index, columns=columns, dtype=dtype + ) return result @@ -1073,7 +1087,7 @@ def str_get_dummies(arr, sep="|"): for i, t in enumerate(tags): pat = sep + t + sep - dummies[:, i] = lib.map_infer(arr.values, lambda x: pat in x) + dummies[:, i] = lib.map_infer(arr.to_numpy(), lambda x: pat in x) return dummies, tags @@ -1858,11 +1872,18 @@ def wrapper(self, *args, **kwargs): return _forbid_nonstring_types -def _noarg_wrapper(f, name=None, docstring=None, forbidden_types=["bytes"], **kargs): +def _noarg_wrapper( + f, + name=None, + docstring=None, + forbidden_types=["bytes"], + returns_string=True, + **kargs +): @forbid_nonstring_types(forbidden_types, name=name) def wrapper(self): result = _na_map(f, self._parent, **kargs) - return self._wrap_result(result) + return self._wrap_result(result, returns_string=returns_string) wrapper.__name__ = f.__name__ if name is None else name if docstring is not None: @@ -1874,22 +1895,28 @@ def wrapper(self): def _pat_wrapper( - f, flags=False, na=False, name=None, forbidden_types=["bytes"], **kwargs + f, + flags=False, + na=False, + name=None, + 
forbidden_types=["bytes"], + returns_string=True, + **kwargs ): @forbid_nonstring_types(forbidden_types, name=name) def wrapper1(self, pat): result = f(self._parent, pat) - return self._wrap_result(result) + return self._wrap_result(result, returns_string=returns_string) @forbid_nonstring_types(forbidden_types, name=name) def wrapper2(self, pat, flags=0, **kwargs): result = f(self._parent, pat, flags=flags, **kwargs) - return self._wrap_result(result) + return self._wrap_result(result, returns_string=returns_string) @forbid_nonstring_types(forbidden_types, name=name) def wrapper3(self, pat, na=np.nan): result = f(self._parent, pat, na=na) - return self._wrap_result(result) + return self._wrap_result(result, returns_string=returns_string) wrapper = wrapper3 if na else wrapper2 if flags else wrapper1 @@ -1926,6 +1953,7 @@ class StringMethods(NoNewAttributesMixin): def __init__(self, data): self._inferred_dtype = self._validate(data) self._is_categorical = is_categorical_dtype(data) + self._is_string = data.dtype.name == "string" # .values.categories works for both Series/Index self._parent = data.values.categories if self._is_categorical else data @@ -1956,6 +1984,8 @@ def _validate(data): ------- dtype : inferred dtype of data """ + from pandas import StringDtype + if isinstance(data, ABCMultiIndex): raise AttributeError( "Can only use .str accessor with Index, not MultiIndex" @@ -1967,6 +1997,10 @@ def _validate(data): values = getattr(data, "values", data) # Series / Index values = getattr(values, "categories", values) # categorical / normal + # explicitly allow StringDtype + if isinstance(values.dtype, StringDtype): + return "string" + try: inferred_dtype = lib.infer_dtype(values, skipna=True) except ValueError: @@ -1992,7 +2026,13 @@ def __iter__(self): g = self.get(i) def _wrap_result( - self, result, use_codes=True, name=None, expand=None, fill_value=np.nan + self, + result, + use_codes=True, + name=None, + expand=None, + fill_value=np.nan, + 
returns_string=True, ): from pandas import Index, Series, MultiIndex @@ -2012,6 +2052,15 @@ def _wrap_result( return result assert result.ndim < 3 + # We can be wrapping a string / object / categorical result, in which + # case we'll want to return the same dtype as the input. + # Or we can be wrapping a numeric output, in which case we don't want + # to return a StringArray. + if self._is_string and returns_string: + dtype = "string" + else: + dtype = None + if expand is None: # infer from ndim if expand is not specified expand = result.ndim != 1 @@ -2069,11 +2118,12 @@ def cons_row(x): index = self._orig.index if expand: cons = self._orig._constructor_expanddim - return cons(result, columns=name, index=index) + result = cons(result, columns=name, index=index, dtype=dtype) else: # Must be a Series cons = self._orig._constructor - return cons(result, name=name, index=index) + result = cons(result, name=name, index=index, dtype=dtype) + return result def _get_series_list(self, others): """ @@ -2338,9 +2388,12 @@ def cat(self, others=None, sep=None, na_rep=None, join="left"): # add dtype for case that result is all-NA result = Index(result, dtype=object, name=self._orig.name) else: # Series - result = Series( - result, dtype=object, index=data.index, name=self._orig.name - ) + if is_categorical_dtype(self._orig.dtype): + # We need to infer the new categories. 
+ dtype = None + else: + dtype = self._orig.dtype + result = Series(result, dtype=dtype, index=data.index, name=self._orig.name) return result _shared_docs[ @@ -2479,13 +2532,13 @@ def cat(self, others=None, sep=None, na_rep=None, join="left"): @forbid_nonstring_types(["bytes"]) def split(self, pat=None, n=-1, expand=False): result = str_split(self._parent, pat, n=n) - return self._wrap_result(result, expand=expand) + return self._wrap_result(result, expand=expand, returns_string=expand) @Appender(_shared_docs["str_split"] % {"side": "end", "method": "rsplit"}) @forbid_nonstring_types(["bytes"]) def rsplit(self, pat=None, n=-1, expand=False): result = str_rsplit(self._parent, pat, n=n) - return self._wrap_result(result, expand=expand) + return self._wrap_result(result, expand=expand, returns_string=expand) _shared_docs[ "str_partition" @@ -2586,7 +2639,7 @@ def rsplit(self, pat=None, n=-1, expand=False): def partition(self, sep=" ", expand=True): f = lambda x: x.partition(sep) result = _na_map(f, self._parent) - return self._wrap_result(result, expand=expand) + return self._wrap_result(result, expand=expand, returns_string=expand) @Appender( _shared_docs["str_partition"] @@ -2602,7 +2655,7 @@ def partition(self, sep=" ", expand=True): def rpartition(self, sep=" ", expand=True): f = lambda x: x.rpartition(sep) result = _na_map(f, self._parent) - return self._wrap_result(result, expand=expand) + return self._wrap_result(result, expand=expand, returns_string=expand) @copy(str_get) def get(self, i): @@ -2621,13 +2674,13 @@ def contains(self, pat, case=True, flags=0, na=np.nan, regex=True): result = str_contains( self._parent, pat, case=case, flags=flags, na=na, regex=regex ) - return self._wrap_result(result, fill_value=na) + return self._wrap_result(result, fill_value=na, returns_string=False) @copy(str_match) @forbid_nonstring_types(["bytes"]) def match(self, pat, case=True, flags=0, na=np.nan): result = str_match(self._parent, pat, case=case, flags=flags, na=na) - 
return self._wrap_result(result, fill_value=na) + return self._wrap_result(result, fill_value=na, returns_string=False) @copy(str_replace) @forbid_nonstring_types(["bytes"]) @@ -2762,13 +2815,14 @@ def slice_replace(self, start=None, stop=None, repl=None): def decode(self, encoding, errors="strict"): # need to allow bytes here result = str_decode(self._parent, encoding, errors) - return self._wrap_result(result) + # TODO: Not sure how to handle this. + return self._wrap_result(result, returns_string=False) @copy(str_encode) @forbid_nonstring_types(["bytes"]) def encode(self, encoding, errors="strict"): result = str_encode(self._parent, encoding, errors) - return self._wrap_result(result) + return self._wrap_result(result, returns_string=False) _shared_docs[ "str_strip" @@ -2869,7 +2923,11 @@ def get_dummies(self, sep="|"): data = self._orig.astype(str) if self._is_categorical else self._parent result, name = str_get_dummies(data, sep) return self._wrap_result( - result, use_codes=(not self._is_categorical), name=name, expand=True + result, + use_codes=(not self._is_categorical), + name=name, + expand=True, + returns_string=False, ) @copy(str_translate) @@ -2878,10 +2936,16 @@ def translate(self, table): result = str_translate(self._parent, table) return self._wrap_result(result) - count = _pat_wrapper(str_count, flags=True, name="count") - startswith = _pat_wrapper(str_startswith, na=True, name="startswith") - endswith = _pat_wrapper(str_endswith, na=True, name="endswith") - findall = _pat_wrapper(str_findall, flags=True, name="findall") + count = _pat_wrapper(str_count, flags=True, name="count", returns_string=False) + startswith = _pat_wrapper( + str_startswith, na=True, name="startswith", returns_string=False + ) + endswith = _pat_wrapper( + str_endswith, na=True, name="endswith", returns_string=False + ) + findall = _pat_wrapper( + str_findall, flags=True, name="findall", returns_string=False + ) @copy(str_extract) @forbid_nonstring_types(["bytes"]) @@ -2929,7 
+2993,7 @@ def extractall(self, pat, flags=0): @forbid_nonstring_types(["bytes"]) def find(self, sub, start=0, end=None): result = str_find(self._parent, sub, start=start, end=end, side="left") - return self._wrap_result(result) + return self._wrap_result(result, returns_string=False) @Appender( _shared_docs["find"] @@ -2942,7 +3006,7 @@ def find(self, sub, start=0, end=None): @forbid_nonstring_types(["bytes"]) def rfind(self, sub, start=0, end=None): result = str_find(self._parent, sub, start=start, end=end, side="right") - return self._wrap_result(result) + return self._wrap_result(result, returns_string=False) @forbid_nonstring_types(["bytes"]) def normalize(self, form): @@ -3004,7 +3068,7 @@ def normalize(self, form): @forbid_nonstring_types(["bytes"]) def index(self, sub, start=0, end=None): result = str_index(self._parent, sub, start=start, end=end, side="left") - return self._wrap_result(result) + return self._wrap_result(result, returns_string=False) @Appender( _shared_docs["index"] @@ -3018,7 +3082,7 @@ def index(self, sub, start=0, end=None): @forbid_nonstring_types(["bytes"]) def rindex(self, sub, start=0, end=None): result = str_index(self._parent, sub, start=start, end=end, side="right") - return self._wrap_result(result) + return self._wrap_result(result, returns_string=False) _shared_docs[ "len" @@ -3067,7 +3131,11 @@ def rindex(self, sub, start=0, end=None): dtype: float64 """ len = _noarg_wrapper( - len, docstring=_shared_docs["len"], forbidden_types=None, dtype=int + len, + docstring=_shared_docs["len"], + forbidden_types=None, + dtype=int, + returns_string=False, ) _shared_docs[ @@ -3339,46 +3407,55 @@ def rindex(self, sub, start=0, end=None): lambda x: x.isalnum(), name="isalnum", docstring=_shared_docs["ismethods"] % _doc_args["isalnum"], + returns_string=False, ) isalpha = _noarg_wrapper( lambda x: x.isalpha(), name="isalpha", docstring=_shared_docs["ismethods"] % _doc_args["isalpha"], + returns_string=False, ) isdigit = _noarg_wrapper( lambda 
x: x.isdigit(), name="isdigit", docstring=_shared_docs["ismethods"] % _doc_args["isdigit"], + returns_string=False, ) isspace = _noarg_wrapper( lambda x: x.isspace(), name="isspace", docstring=_shared_docs["ismethods"] % _doc_args["isspace"], + returns_string=False, ) islower = _noarg_wrapper( lambda x: x.islower(), name="islower", docstring=_shared_docs["ismethods"] % _doc_args["islower"], + returns_string=False, ) isupper = _noarg_wrapper( lambda x: x.isupper(), name="isupper", docstring=_shared_docs["ismethods"] % _doc_args["isupper"], + returns_string=False, ) istitle = _noarg_wrapper( lambda x: x.istitle(), name="istitle", docstring=_shared_docs["ismethods"] % _doc_args["istitle"], + returns_string=False, ) isnumeric = _noarg_wrapper( lambda x: x.isnumeric(), name="isnumeric", docstring=_shared_docs["ismethods"] % _doc_args["isnumeric"], + returns_string=False, ) isdecimal = _noarg_wrapper( lambda x: x.isdecimal(), name="isdecimal", docstring=_shared_docs["ismethods"] % _doc_args["isdecimal"], + returns_string=False, ) @classmethod diff --git a/pandas/tests/api/test_api.py b/pandas/tests/api/test_api.py index 2f24bbd6f0c853..6c50159663574d 100644 --- a/pandas/tests/api/test_api.py +++ b/pandas/tests/api/test_api.py @@ -68,6 +68,7 @@ class TestPDApi(Base): "Series", "SparseArray", "SparseDtype", + "StringDtype", "Timedelta", "TimedeltaIndex", "Timestamp", diff --git a/pandas/tests/arrays/string_/__init__.py b/pandas/tests/arrays/string_/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py new file mode 100644 index 00000000000000..40221c34116aee --- /dev/null +++ b/pandas/tests/arrays/string_/test_string.py @@ -0,0 +1,160 @@ +import operator + +import numpy as np +import pytest + +import pandas as pd +import pandas.util.testing as tm + + +def test_none_to_nan(): + a = pd.arrays.StringArray._from_sequence(["a", None, "b"]) + assert a[1] is not 
None + assert np.isnan(a[1]) + + +def test_setitem_validates(): + a = pd.arrays.StringArray._from_sequence(["a", "b"]) + with pytest.raises(ValueError, match="10"): + a[0] = 10 + + with pytest.raises(ValueError, match="strings"): + a[:] = np.array([1, 2]) + + +@pytest.mark.parametrize( + "input, method", + [ + (["a", "b", "c"], operator.methodcaller("capitalize")), + (["a b", "a bc. de"], operator.methodcaller("capitalize")), + ], +) +def test_string_methods(input, method): + a = pd.Series(input, dtype="string") + b = pd.Series(input, dtype="object") + result = method(a.str) + expected = method(b.str) + + assert result.dtype.name == "string" + tm.assert_series_equal(result.astype(object), expected) + + +def test_astype_roundtrip(): + s = pd.Series(pd.date_range("2000", periods=12)) + s[0] = None + + result = s.astype("string").astype("datetime64[ns]") + tm.assert_series_equal(result, s) + + +def test_add(): + a = pd.Series(["a", "b", "c", None, None], dtype="string") + b = pd.Series(["x", "y", None, "z", None], dtype="string") + + result = a + b + expected = pd.Series(["ax", "by", None, None, None], dtype="string") + tm.assert_series_equal(result, expected) + + result = a.add(b) + tm.assert_series_equal(result, expected) + + result = a.radd(b) + expected = pd.Series(["xa", "yb", None, None, None], dtype="string") + tm.assert_series_equal(result, expected) + + result = a.add(b, fill_value="-") + expected = pd.Series(["ax", "by", "c-", "-z", None], dtype="string") + tm.assert_series_equal(result, expected) + + +def test_add_2d(): + a = pd.array(["a", "b", "c"], dtype="string") + b = np.array([["a", "b", "c"]], dtype=object) + with pytest.raises(ValueError, match="3 != 1"): + a + b + + s = pd.Series(a) + with pytest.raises(ValueError, match="3 != 1"): + s + b + + +def test_add_sequence(): + a = pd.array(["a", "b", None, None], dtype="string") + other = ["x", None, "y", None] + + result = a + other + expected = 
pd.array(["ax", None, None, None], dtype="string") + tm.assert_extension_array_equal(result, expected) + + result = other + a + expected = pd.array(["xa", None, None, None], dtype="string") + tm.assert_extension_array_equal(result, expected) + + +def test_mul(): + a = pd.array(["a", "b", None], dtype="string") + result = a * 2 + expected = pd.array(["aa", "bb", None], dtype="string") + tm.assert_extension_array_equal(result, expected) + + result = 2 * a + tm.assert_extension_array_equal(result, expected) + + +@pytest.mark.xfail(reason="GH-28527") +def test_add_strings(): + array = pd.array(["a", "b", "c", "d"], dtype="string") + df = pd.DataFrame([["t", "u", "v", "w"]]) + assert array.__add__(df) is NotImplemented + + result = array + df + expected = pd.DataFrame([["at", "bu", "cv", "dw"]]).astype("string") + tm.assert_frame_equal(result, expected) + + result = df + array + expected = pd.DataFrame([["ta", "ub", "vc", "wd"]]).astype("string") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.xfail(reason="GH-28527") +def test_add_frame(): + array = pd.array(["a", "b", np.nan, np.nan], dtype="string") + df = pd.DataFrame([["x", np.nan, "y", np.nan]]) + + assert array.__add__(df) is NotImplemented + + result = array + df + expected = pd.DataFrame([["ax", np.nan, np.nan, np.nan]]).astype("string") + tm.assert_frame_equal(result, expected) + + result = df + array + expected = pd.DataFrame([["xa", np.nan, np.nan, np.nan]]).astype("string") + tm.assert_frame_equal(result, expected) + + +def test_constructor_raises(): + with pytest.raises(ValueError, match="sequence of strings"): + pd.arrays.StringArray(np.array(["a", "b"], dtype="S1")) + + with pytest.raises(ValueError, match="sequence of strings"): + pd.arrays.StringArray(np.array([])) + + +@pytest.mark.parametrize("skipna", [True, False]) +@pytest.mark.xfail(reason="Not implemented StringArray.sum") +def test_reduce(skipna): + arr = pd.Series(["a", "b", "c"], dtype="string") + result = arr.sum(skipna=skipna) + 
assert result == "abc" + + +@pytest.mark.parametrize("skipna", [True, False]) +@pytest.mark.xfail(reason="Not implemented StringArray.sum") +def test_reduce_missing(skipna): + arr = pd.Series([None, "a", None, "b", "c", None], dtype="string") + result = arr.sum(skipna=skipna) + if skipna: + assert result == "abc" + else: + assert pd.isna(result) diff --git a/pandas/tests/dtypes/test_common.py b/pandas/tests/dtypes/test_common.py index 266f7ac50c6634..466b724f98770e 100644 --- a/pandas/tests/dtypes/test_common.py +++ b/pandas/tests/dtypes/test_common.py @@ -291,6 +291,8 @@ def test_is_string_dtype(): assert com.is_string_dtype(str) assert com.is_string_dtype(object) assert com.is_string_dtype(np.array(["a", "b"])) + assert com.is_string_dtype(pd.StringDtype()) + assert com.is_string_dtype(pd.array(["a", "b"], dtype="string")) def test_is_period_arraylike(): diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py new file mode 100644 index 00000000000000..5b872d5b722279 --- /dev/null +++ b/pandas/tests/extension/test_string.py @@ -0,0 +1,112 @@ +import string + +import numpy as np +import pytest + +import pandas as pd +from pandas.core.arrays.string_ import StringArray, StringDtype +from pandas.tests.extension import base + + +@pytest.fixture +def dtype(): + return StringDtype() + + +@pytest.fixture +def data(): + strings = np.random.choice(list(string.ascii_letters), size=100) + while strings[0] == strings[1]: + strings = np.random.choice(list(string.ascii_letters), size=100) + + return StringArray._from_sequence(strings) + + +@pytest.fixture +def data_missing(): + """Length 2 array with [NA, Valid]""" + return StringArray._from_sequence([np.nan, "A"]) + + +@pytest.fixture +def data_for_sorting(): + return StringArray._from_sequence(["B", "C", "A"]) + + +@pytest.fixture +def data_missing_for_sorting(): + return StringArray._from_sequence(["B", np.nan, "A"]) + + +@pytest.fixture +def na_value(): + return np.nan + + +@pytest.fixture 
+def data_for_grouping(): + return StringArray._from_sequence(["B", "B", np.nan, np.nan, "A", "A", "B", "C"]) + + +class TestDtype(base.BaseDtypeTests): + pass + + +class TestInterface(base.BaseInterfaceTests): + pass + + +class TestConstructors(base.BaseConstructorsTests): + pass + + +class TestReshaping(base.BaseReshapingTests): + pass + + +class TestGetitem(base.BaseGetitemTests): + pass + + +class TestSetitem(base.BaseSetitemTests): + pass + + +class TestMissing(base.BaseMissingTests): + pass + + +class TestNoReduce(base.BaseNoReduceTests): + pass + + +class TestMethods(base.BaseMethodsTests): + pass + + +class TestCasting(base.BaseCastingTests): + pass + + +class TestComparisonOps(base.BaseComparisonOpsTests): + def _compare_other(self, s, data, op_name, other): + result = getattr(s, op_name)(other) + expected = getattr(s.astype(object), op_name)(other) + self.assert_series_equal(result, expected) + + def test_compare_scalar(self, data, all_compare_operators): + op_name = all_compare_operators + s = pd.Series(data) + self._compare_other(s, data, op_name, "abc") + + +class TestParsing(base.BaseParsingTests): + pass + + +class TestPrinting(base.BasePrintingTests): + pass + + +class TestGroupBy(base.BaseGroupbyTests): + pass diff --git a/pandas/tests/test_strings.py b/pandas/tests/test_strings.py index bc8dc7272a83a3..b50f1a0fd2f2ab 100644 --- a/pandas/tests/test_strings.py +++ b/pandas/tests/test_strings.py @@ -6,6 +6,8 @@ from numpy.random import randint import pytest +from pandas._libs import lib + from pandas import DataFrame, Index, MultiIndex, Series, concat, isna, notna import pandas.core.strings as strings import pandas.util.testing as tm @@ -3269,3 +3271,25 @@ def test_casefold(self): result = s.str.casefold() tm.assert_series_equal(result, expected) + + +def test_string_array(any_string_method): + data = ["a", "bb", np.nan, "ccc"] + a = Series(data, dtype=object) + b = Series(data, dtype="string") + method_name, args, kwargs = any_string_method + + 
expected = getattr(a.str, method_name)(*args, **kwargs) + result = getattr(b.str, method_name)(*args, **kwargs) + + if isinstance(expected, Series): + if expected.dtype == "object" and lib.is_string_array( + expected.values, skipna=True + ): + assert result.dtype == "string" + result = result.astype(object) + elif isinstance(expected, DataFrame): + columns = expected.select_dtypes(include="object").columns + assert all(result[columns].dtypes == "string") + result[columns] = result[columns].astype(object) + tm.assert_equal(result, expected) diff --git a/pandas/util/testing.py b/pandas/util/testing.py index 32f88b13ac041f..a34fdee227afc9 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -1431,6 +1431,9 @@ def assert_equal(left, right, **kwargs): assert_extension_array_equal(left, right, **kwargs) elif isinstance(left, np.ndarray): assert_numpy_array_equal(left, right, **kwargs) + elif isinstance(left, str): + assert kwargs == {} + assert left == right else: raise NotImplementedError(type(left))