Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compat: Ensure that pandas dtype matches dask when loading data from parquet #156

Merged
merged 10 commits into from
Nov 12, 2024
16 changes: 15 additions & 1 deletion spatialpandas/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import fsspec
import pandas as pd
import pyarrow as pa
from dask import delayed
from dask.dataframe import (
from_delayed,
Expand Down Expand Up @@ -107,6 +108,7 @@ def read_parquet(
filesystem: Optional[fsspec.AbstractFileSystem] = None,
storage_options: Optional[dict[str, Any]] = None,
engine_kwargs: Optional[dict[str, Any]] = None,
convert_string: bool = False,
**kwargs: Any,
) -> GeoDataFrame:
engine_kwargs = engine_kwargs or {}
Expand Down Expand Up @@ -139,7 +141,11 @@ def read_parquet(

columns = extra_index_columns + list(columns)

df = dataset.read(columns=columns).to_pandas()
# References:
# https://arrow.apache.org/docs/python/pandas.html
# https://pandas.pydata.org/docs/user_guide/pyarrow.html
type_mapping = {pa.string(): pd.StringDtype("pyarrow"), pa.large_string(): pd.StringDtype("pyarrow")} if convert_string else {}
df = dataset.read(columns=columns).to_pandas(types_mapper=type_mapping.get, self_destruct=True)

# Return result
return GeoDataFrame(df)
Expand Down Expand Up @@ -332,13 +338,21 @@ def _perform_read_parquet_dask(
dataset_pieces = sorted(fragments, key=lambda piece: natural_sort_key(piece.path))
pieces.extend(dataset_pieces)

# https://docs.dask.org/en/stable/changelog.html#v2023-7-1
try:
from dask.dataframe.utils import pyarrow_strings_enabled
hoxbro marked this conversation as resolved.
Show resolved Hide resolved

convert_string = pyarrow_strings_enabled()
except (ImportError, RuntimeError):
convert_string = False
delayed_partitions = [
delayed(read_parquet)(
piece.path,
columns=columns,
filesystem=filesystem,
storage_options=storage_options,
engine_kwargs=engine_kwargs,
convert_string=convert_string,
)
for piece in pieces
]
Expand Down
30 changes: 29 additions & 1 deletion spatialpandas/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
from hypothesis import HealthCheck, Phase, Verbosity, given, settings

from spatialpandas import GeoDataFrame, GeoSeries
from spatialpandas import GeoDataFrame, GeoSeries, geometry
from spatialpandas.dask import DaskGeoDataFrame
from spatialpandas.io import read_parquet, read_parquet_dask, to_parquet

Expand Down Expand Up @@ -440,3 +440,31 @@ def test_read_parquet_dask(directory, repartitioned):
assert partition_bounds["multiline"].index.name == "partition"
assert ddf["multiline"].partition_bounds.index.name == "partition"
assert all(partition_bounds["multiline"] == ddf["multiline"].partition_bounds)


def test_parquet_dask_string_conversion():
    """Smoke test: dask exposes ``pyarrow_strings_enabled`` and it returns a bool."""
    from dask.dataframe.utils import pyarrow_strings_enabled

    enabled = pyarrow_strings_enabled()
    assert isinstance(enabled, bool)


@pytest.mark.parametrize("save_convert_string", [True, False])
@pytest.mark.parametrize("load_convert_string", [True, False])
def test_parquet_dask_string_convert(save_convert_string, load_convert_string, tmp_path):
    """Round-trip a GeoDataFrame through parquet under every combination of the
    dask ``dataframe.convert-string`` setting at save and load time, checking
    the dtype of the plain (non-geometry) string column after the read."""
    path = tmp_path / "test.parq"
    square = geometry.Polygon([(0, 0), (0, 1), (1, 1), (1, 0)])

    # Write the dataset with the save-time convert-string setting active.
    with dask.config.set({"dataframe.convert-string": save_convert_string}):
        frame = GeoDataFrame({"geometry": GeoSeries([square, square]), "name": ["A", "B"]})
        dd.from_pandas(frame, 2).to_parquet(path)

    # Expected dtype of the "name" column depends on both settings:
    # load-time conversion wins; otherwise save-time conversion leaves
    # python-backed strings; with neither, plain object dtype survives.
    if load_convert_string:
        expected = pd.StringDtype("pyarrow")
    elif save_convert_string:
        expected = pd.StringDtype("python")
    else:
        expected = np.dtype("object")

    # Read back with the load-time convert-string setting active.
    with dask.config.set({"dataframe.convert-string": load_convert_string}):
        result = read_parquet_dask(path)
        assert result["name"].dtype == expected
        assert result.compute()["name"].dtype == expected