Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the data_kind and validate_data_input functions #3335

Merged
merged 11 commits into from
Jul 20, 2024
12 changes: 10 additions & 2 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
GMTVersionError,
)
from pygmt.helpers import (
_validate_data_input,
data_kind,
tempfile_from_geojson,
tempfile_from_image,
Expand Down Expand Up @@ -1684,8 +1685,15 @@ def virtualfile_in( # noqa: PLR0912
... print(fout.read().strip())
<vector memory>: N = 3 <7/9> <4/6> <1/3>
"""
kind = data_kind(
data, x, y, z, required_z=required_z, required_data=required_data
kind = data_kind(data, required=required_data)
_validate_data_input(
data=data,
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
kind=kind,
)

if check_kind:
Expand Down
1 change: 1 addition & 0 deletions pygmt/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
unique_name,
)
from pygmt.helpers.utils import (
_validate_data_input,
args_in_kwargs,
build_arg_list,
build_arg_string,
Expand Down
71 changes: 29 additions & 42 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import warnings
import webbrowser
from collections.abc import Iterable, Sequence
from typing import Any
from typing import Any, Literal

import xarray as xr
from pygmt.encodings import charset
Expand Down Expand Up @@ -79,6 +79,10 @@ def _validate_data_input(
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", z=[7, 8, 9])
Traceback (most recent call last):
...
Expand Down Expand Up @@ -111,77 +115,69 @@ def _validate_data_input(
raise GMTInvalidInput("data must provide x, y, and z columns.")


def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data=True):
def data_kind(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this function only checks data; x/y/z are no longer used.

data: Any = None, required: bool = True
) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
"""
Check what kind of data is provided to a module.
Check the kind of data that is provided to a module.

Possible types:
The ``data`` argument can be of any type, but only the following types are supported:

* a file name provided as 'data'
* a pathlib.PurePath object provided as 'data'
* an xarray.DataArray object provided as 'data'
* a 2-D matrix provided as 'data'
* 1-D arrays x and y (and z, optionally)
* an optional argument (None, bool, int or float) provided as 'data'

Arguments should be ``None`` if not used. If doesn't fit any of these
categories (or fits more than one), will raise an exception.
- a string or a :class:`pathlib.PurePath` object or a sequence of them, representing
a file name or a list of file names
- a 2-D or 3-D :class:`xarray.DataArray` object
- a 2-D matrix
- None, bool, int or float type representing an optional argument
- a geo-like Python object that implements ``__geo_interface__`` (e.g.,
geopandas.GeoDataFrame or shapely.geometry)

Parameters
----------
data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like}
Pass in either a file name or :class:`pathlib.Path` to an ASCII data
table, an :class:`xarray.DataArray`, a 1-D/2-D
{table-classes} or an option argument.
x/y : 1-D arrays or None
x and y columns as numpy arrays.
z : 1-D array or None
z column as numpy array. To be used optionally when x and y are given.
required_z : bool
State whether the 'z' column is required.
required_data : bool
required
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].

Returns
-------
kind : str
One of ``'arg'``, ``'file'``, ``'grid'``, ``image``, ``'geojson'``,
``'matrix'``, or ``'vectors'``.
kind
The data kind.

Examples
--------

>>> import numpy as np
>>> import xarray as xr
>>> import pathlib
>>> data_kind(data=None, x=np.array([1, 2, 3]), y=np.array([4, 5, 6]))
>>> data_kind(data=None)
'vectors'
>>> data_kind(data=np.arange(10).reshape((5, 2)), x=None, y=None)
>>> data_kind(data=np.arange(10).reshape((5, 2)))
'matrix'
>>> data_kind(data="my-data-file.txt", x=None, y=None)
>>> data_kind(data="my-data-file.txt")
'file'
>>> data_kind(data=pathlib.Path("my-data-file.txt"), x=None, y=None)
>>> data_kind(data=pathlib.Path("my-data-file.txt"))
'file'
>>> data_kind(data=None, x=None, y=None, required_data=False)
>>> data_kind(data=None, required=False)
'arg'
>>> data_kind(data=2.0, x=None, y=None, required_data=False)
>>> data_kind(data=2.0, required=False)
'arg'
>>> data_kind(data=True, x=None, y=None, required_data=False)
>>> data_kind(data=True, required=False)
'arg'
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3)))
'grid'
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5)))
'image'
"""
# determine the data kind
kind: Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this line doing here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To suppress the mypy error:

pygmt/helpers/utils.py:191: error: Incompatible return value type (got "str", expected "Literal['arg', 'file', 'geojson', 'grid', 'image', 'matrix', 'vectors']")  [return-value]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah interesting. Probably should turn this into an enum, but leave that for another day.

if isinstance(data, str | pathlib.PurePath) or (
isinstance(data, list | tuple)
and all(isinstance(_file, str | pathlib.PurePath) for _file in data)
):
# One or more files
kind = "file"
elif isinstance(data, bool | int | float) or (data is None and not required_data):
elif isinstance(data, bool | int | float) or (data is None and not required):
kind = "arg"
elif isinstance(data, xr.DataArray):
kind = "image" if len(data.dims) == 3 else "grid"
Expand All @@ -193,15 +189,6 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data
kind = "matrix"
else:
kind = "vectors"
_validate_data_input(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed from here and moved to Session.virtualfile_in.

data=data,
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
kind=kind,
)
return kind


Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def plot( # noqa: PLR0912
"""
kwargs = self._preprocess(**kwargs)

kind = data_kind(data, x, y)
kind = data_kind(data)
extra_arrays = []
if kind == "vectors": # Add more columns for vectors input
# Parameters for vector styles
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def plot3d( # noqa: PLR0912
"""
kwargs = self._preprocess(**kwargs)

kind = data_kind(data, x, y, z)
kind = data_kind(data)
extra_arrays = []

if kind == "vectors": # Add more columns for vectors input
Expand Down
4 changes: 2 additions & 2 deletions pygmt/src/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,11 @@ def text_( # noqa: PLR0912

# Ensure inputs are either textfiles, x/y/text, or position/text
if position is None:
if (x is not None or y is not None) and textfiles is not None:
if any(v is not None for v in (x, y, text)) and textfiles is not None:
raise GMTInvalidInput(
"Provide either position only, or x/y pairs, or textfiles."
)
kind = data_kind(textfiles, x, y, text)
kind = data_kind(textfiles)
if kind == "vectors" and text is None:
raise GMTInvalidInput("Must provide text with x/y pairs")
else:
Expand Down
21 changes: 0 additions & 21 deletions pygmt/tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from pathlib import Path

import numpy as np
import pytest
import xarray as xr
from pygmt import Figure
Expand All @@ -13,7 +12,6 @@
GMTTempFile,
args_in_kwargs,
build_arg_list,
data_kind,
kwargs_to_strings,
unique_name,
)
Expand All @@ -33,25 +31,6 @@ def test_load_static_earth_relief():
assert isinstance(data, xr.DataArray)


@pytest.mark.parametrize(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is removed because:

  1. data_kind no longer checks if the input is valid
  2. All these tests are already covered by the validate_data_input doctests.

("data", "x", "y"),
[
(None, None, None),
("data.txt", np.array([1, 2]), np.array([4, 5])),
("data.txt", np.array([1, 2]), None),
("data.txt", None, np.array([4, 5])),
(None, np.array([1, 2]), None),
(None, None, np.array([4, 5])),
],
)
def test_data_kind_fails(data, x, y):
"""
Make sure data_kind raises exceptions when it should.
"""
with pytest.raises(GMTInvalidInput):
data_kind(data=data, x=x, y=y)


def test_unique_name():
"""
Make sure the names are really unique.
Expand Down