Skip to content

Commit

Permalink
Refactor the data_kind and validate_data_input functions (#3335)
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman authored Jul 20, 2024
1 parent d3101f3 commit 9ec92be
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 69 deletions.
12 changes: 10 additions & 2 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
GMTVersionError,
)
from pygmt.helpers import (
_validate_data_input,
data_kind,
tempfile_from_geojson,
tempfile_from_image,
Expand Down Expand Up @@ -1684,8 +1685,15 @@ def virtualfile_in( # noqa: PLR0912
... print(fout.read().strip())
<vector memory>: N = 3 <7/9> <4/6> <1/3>
"""
kind = data_kind(
data, x, y, z, required_z=required_z, required_data=required_data
kind = data_kind(data, required=required_data)
_validate_data_input(
data=data,
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
kind=kind,
)

if check_kind:
Expand Down
1 change: 1 addition & 0 deletions pygmt/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
unique_name,
)
from pygmt.helpers.utils import (
_validate_data_input,
args_in_kwargs,
build_arg_list,
build_arg_string,
Expand Down
71 changes: 29 additions & 42 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import warnings
import webbrowser
from collections.abc import Iterable, Sequence
from typing import Any
from typing import Any, Literal

import xarray as xr
from pygmt.encodings import charset
Expand Down Expand Up @@ -79,6 +79,10 @@ def _validate_data_input(
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", z=[7, 8, 9])
Traceback (most recent call last):
...
Expand Down Expand Up @@ -111,77 +115,69 @@ def _validate_data_input(
raise GMTInvalidInput("data must provide x, y, and z columns.")


def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data=True):
def data_kind(
data: Any = None, required: bool = True
) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
"""
Check what kind of data is provided to a module.
Check the kind of data that is provided to a module.
Possible types:
The ``data`` argument can be in any type, but only following types are supported:
* a file name provided as 'data'
* a pathlib.PurePath object provided as 'data'
* an xarray.DataArray object provided as 'data'
* a 2-D matrix provided as 'data'
* 1-D arrays x and y (and z, optionally)
* an optional argument (None, bool, int or float) provided as 'data'
Arguments should be ``None`` if not used. If doesn't fit any of these
categories (or fits more than one), will raise an exception.
- a string or a :class:`pathlib.PurePath` object or a sequence of them, representing
a file name or a list of file names
- a 2-D or 3-D :class:`xarray.DataArray` object
- a 2-D matrix
- None, bool, int or float type representing an optional arguments
- a geo-like Python object that implements ``__geo_interface__`` (e.g.,
geopandas.GeoDataFrame or shapely.geometry)
Parameters
----------
data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like}
Pass in either a file name or :class:`pathlib.Path` to an ASCII data
table, an :class:`xarray.DataArray`, a 1-D/2-D
{table-classes} or an option argument.
x/y : 1-D arrays or None
x and y columns as numpy arrays.
z : 1-D array or None
z column as numpy array. To be used optionally when x and y are given.
required_z : bool
State whether the 'z' column is required.
required_data : bool
required
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].
Returns
-------
kind : str
One of ``'arg'``, ``'file'``, ``'grid'``, ``image``, ``'geojson'``,
``'matrix'``, or ``'vectors'``.
kind
The data kind.
Examples
--------
>>> import numpy as np
>>> import xarray as xr
>>> import pathlib
>>> data_kind(data=None, x=np.array([1, 2, 3]), y=np.array([4, 5, 6]))
>>> data_kind(data=None)
'vectors'
>>> data_kind(data=np.arange(10).reshape((5, 2)), x=None, y=None)
>>> data_kind(data=np.arange(10).reshape((5, 2)))
'matrix'
>>> data_kind(data="my-data-file.txt", x=None, y=None)
>>> data_kind(data="my-data-file.txt")
'file'
>>> data_kind(data=pathlib.Path("my-data-file.txt"), x=None, y=None)
>>> data_kind(data=pathlib.Path("my-data-file.txt"))
'file'
>>> data_kind(data=None, x=None, y=None, required_data=False)
>>> data_kind(data=None, required=False)
'arg'
>>> data_kind(data=2.0, x=None, y=None, required_data=False)
>>> data_kind(data=2.0, required=False)
'arg'
>>> data_kind(data=True, x=None, y=None, required_data=False)
>>> data_kind(data=True, required=False)
'arg'
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3)))
'grid'
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5)))
'image'
"""
# determine the data kind
kind: Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]
if isinstance(data, str | pathlib.PurePath) or (
isinstance(data, list | tuple)
and all(isinstance(_file, str | pathlib.PurePath) for _file in data)
):
# One or more files
kind = "file"
elif isinstance(data, bool | int | float) or (data is None and not required_data):
elif isinstance(data, bool | int | float) or (data is None and not required):
kind = "arg"
elif isinstance(data, xr.DataArray):
kind = "image" if len(data.dims) == 3 else "grid"
Expand All @@ -193,15 +189,6 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data
kind = "matrix"
else:
kind = "vectors"
_validate_data_input(
data=data,
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
kind=kind,
)
return kind


Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def plot( # noqa: PLR0912
"""
kwargs = self._preprocess(**kwargs)

kind = data_kind(data, x, y)
kind = data_kind(data)
extra_arrays = []
if kind == "vectors": # Add more columns for vectors input
# Parameters for vector styles
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def plot3d( # noqa: PLR0912
"""
kwargs = self._preprocess(**kwargs)

kind = data_kind(data, x, y, z)
kind = data_kind(data)
extra_arrays = []

if kind == "vectors": # Add more columns for vectors input
Expand Down
4 changes: 2 additions & 2 deletions pygmt/src/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,11 @@ def text_( # noqa: PLR0912

# Ensure inputs are either textfiles, x/y/text, or position/text
if position is None:
if (x is not None or y is not None) and textfiles is not None:
if any(v is not None for v in (x, y, text)) and textfiles is not None:
raise GMTInvalidInput(
"Provide either position only, or x/y pairs, or textfiles."
)
kind = data_kind(textfiles, x, y, text)
kind = data_kind(textfiles)
if kind == "vectors" and text is None:
raise GMTInvalidInput("Must provide text with x/y pairs")
else:
Expand Down
21 changes: 0 additions & 21 deletions pygmt/tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from pathlib import Path

import numpy as np
import pytest
import xarray as xr
from pygmt import Figure
Expand All @@ -13,7 +12,6 @@
GMTTempFile,
args_in_kwargs,
build_arg_list,
data_kind,
kwargs_to_strings,
unique_name,
)
Expand All @@ -33,25 +31,6 @@ def test_load_static_earth_relief():
assert isinstance(data, xr.DataArray)


@pytest.mark.parametrize(
("data", "x", "y"),
[
(None, None, None),
("data.txt", np.array([1, 2]), np.array([4, 5])),
("data.txt", np.array([1, 2]), None),
("data.txt", None, np.array([4, 5])),
(None, np.array([1, 2]), None),
(None, None, np.array([4, 5])),
],
)
def test_data_kind_fails(data, x, y):
"""
Make sure data_kind raises exceptions when it should.
"""
with pytest.raises(GMTInvalidInput):
data_kind(data=data, x=x, y=y)


def test_unique_name():
"""
Make sure the names are really unique.
Expand Down

0 comments on commit 9ec92be

Please sign in to comment.