diff --git a/spatialpandas/io/parquet.py b/spatialpandas/io/parquet.py index 839b670..4d6dc8a 100644 --- a/spatialpandas/io/parquet.py +++ b/spatialpandas/io/parquet.py @@ -8,6 +8,7 @@ import fsspec import pandas as pd +import pyarrow as pa from dask import delayed from dask.dataframe import ( from_delayed, @@ -107,6 +108,7 @@ def read_parquet( filesystem: Optional[fsspec.AbstractFileSystem] = None, storage_options: Optional[dict[str, Any]] = None, engine_kwargs: Optional[dict[str, Any]] = None, + convert_string: bool = False, **kwargs: Any, ) -> GeoDataFrame: engine_kwargs = engine_kwargs or {} @@ -139,7 +141,11 @@ def read_parquet( columns = extra_index_columns + list(columns) - df = dataset.read(columns=columns).to_pandas() + # References: + # https://arrow.apache.org/docs/python/pandas.html + # https://pandas.pydata.org/docs/user_guide/pyarrow.html + type_mapping = {pa.string(): pd.StringDtype("pyarrow"), pa.large_string(): pd.StringDtype("pyarrow")} if convert_string else {} + df = dataset.read(columns=columns).to_pandas(types_mapper=type_mapping.get, self_destruct=True) # Return result return GeoDataFrame(df) @@ -332,6 +338,13 @@ def _perform_read_parquet_dask( dataset_pieces = sorted(fragments, key=lambda piece: natural_sort_key(piece.path)) pieces.extend(dataset_pieces) + # https://docs.dask.org/en/stable/changelog.html#v2023-7-1 + try: + from dask.dataframe.utils import pyarrow_strings_enabled + + convert_string = pyarrow_strings_enabled() + except (ImportError, RuntimeError): + convert_string = False delayed_partitions = [ delayed(read_parquet)( piece.path, @@ -339,6 +352,7 @@ def _perform_read_parquet_dask( filesystem=filesystem, storage_options=storage_options, engine_kwargs=engine_kwargs, + convert_string=convert_string, ) for piece in pieces ] diff --git a/spatialpandas/tests/test_parquet.py b/spatialpandas/tests/test_parquet.py index ab1a14c..a575321 100644 --- a/spatialpandas/tests/test_parquet.py +++ b/spatialpandas/tests/test_parquet.py @@ -9,7 +9,7 @@ import pytest from hypothesis import HealthCheck, Phase, Verbosity, given, settings -from spatialpandas import GeoDataFrame, GeoSeries +from spatialpandas import GeoDataFrame, GeoSeries, geometry from spatialpandas.dask import DaskGeoDataFrame from spatialpandas.io import read_parquet, read_parquet_dask, to_parquet @@ -440,3 +440,31 @@ def test_read_parquet_dask(directory, repartitioned): assert partition_bounds["multiline"].index.name == "partition" assert ddf["multiline"].partition_bounds.index.name == "partition" assert all(partition_bounds["multiline"] == ddf["multiline"].partition_bounds) + + +def test_parquet_dask_string_conversion(): + from dask.dataframe.utils import pyarrow_strings_enabled + + assert isinstance(pyarrow_strings_enabled(), bool) + + +@pytest.mark.parametrize("save_convert_string", [True, False]) +@pytest.mark.parametrize("load_convert_string", [True, False]) +def test_parquet_dask_string_convert(save_convert_string, load_convert_string, tmp_path): + with dask.config.set({"dataframe.convert-string": save_convert_string}): + square = geometry.Polygon([(0, 0), (0, 1), (1, 1), (1, 0)]) + sdf = GeoDataFrame({"geometry": GeoSeries([square, square]), "name": ["A", "B"]}) + sddf = dd.from_pandas(sdf, 2) + sddf.to_parquet(tmp_path / "test.parq") + + if load_convert_string: + dtype = pd.StringDtype("pyarrow") + elif save_convert_string: + dtype = pd.StringDtype("python") + else: + dtype = np.dtype("object") + + with dask.config.set({"dataframe.convert-string": load_convert_string}): + data = read_parquet_dask(tmp_path / "test.parq") + assert data["name"].dtype == dtype + assert data.compute()["name"].dtype == dtype