Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the data_kind and validate_data_input functions #3335

Merged
merged 11 commits into from
Jul 20, 2024
12 changes: 10 additions & 2 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
data_kind,
tempfile_from_geojson,
tempfile_from_image,
validate_data_input,
)

FAMILIES = [
Expand Down Expand Up @@ -1591,8 +1592,15 @@ def virtualfile_in( # noqa: PLR0912
... print(fout.read().strip())
<vector memory>: N = 3 <7/9> <4/6> <1/3>
"""
kind = data_kind(
data, x, y, z, required_z=required_z, required_data=required_data
kind = data_kind(data, required=required_data)
validate_data_input(
data=data,
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
kind=kind,
)

if check_kind:
Expand Down
1 change: 1 addition & 0 deletions pygmt/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
is_nonstr_iter,
launch_external_viewer,
non_ascii_to_octal,
validate_data_input,
)
from pygmt.helpers.validators import validate_output_table_type
82 changes: 34 additions & 48 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,74 +12,78 @@
import warnings
import webbrowser
from collections.abc import Iterable, Sequence
from typing import Any
from typing import Any, Literal

import xarray as xr
from pygmt.encodings import charset
from pygmt.exceptions import GMTInvalidInput


def _validate_data_input(
def validate_data_input(
seisman marked this conversation as resolved.
Show resolved Hide resolved
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
):
"""
Check if the combination of data/x/y/z is valid.

Examples
--------
>>> _validate_data_input(data="infile")
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6])
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9])
>>> _validate_data_input(data=None, required_data=False)
>>> _validate_data_input()
>>> validate_data_input(data="infile")
>>> validate_data_input(x=[1, 2, 3], y=[4, 5, 6])
>>> validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9])
>>> validate_data_input(data=None, required_data=False)
>>> validate_data_input()
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: No input data provided.
>>> _validate_data_input(x=[1, 2, 3])
>>> validate_data_input(x=[1, 2, 3])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(y=[4, 5, 6])
>>> validate_data_input(y=[4, 5, 6])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True)
>>> validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True)
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
>>> import numpy as np
>>> import pandas as pd
>>> import xarray as xr
>>> data = np.arange(8).reshape((4, 2))
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
>>> validate_data_input(data=data, required_z=True, kind="matrix")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(
>>> validate_data_input(
... data=pd.DataFrame(data, columns=["x", "y"]),
... required_z=True,
... kind="matrix",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(
>>> validate_data_input(
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
... required_z=True,
... kind="matrix",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(data="infile", x=[1, 2, 3])
>>> validate_data_input(data="infile", x=[1, 2, 3])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", y=[4, 5, 6])
>>> validate_data_input(data="infile", y=[4, 5, 6])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", z=[7, 8, 9])
>>> validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> validate_data_input(data="infile", z=[7, 8, 9])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
Expand Down Expand Up @@ -111,7 +115,9 @@ def _validate_data_input(
raise GMTInvalidInput("data must provide x, y, and z columns.")


def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data=True):
def data_kind(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this function only checks data; x/y/z are no longer used.

data: Any = None, required: bool = True
) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
"""
Check what kind of data is provided to a module.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since x/y/z are no longer used, change the docstring at L124-129 to read:

    Possible types (provided to `data`):
    * a file name
    * a pathlib.PurePath object
    * an xarray.DataArray object
    * a 2-D matrix

Expand All @@ -124,50 +130,39 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data
* 1-D arrays x and y (and z, optionally)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete this line about x/y/z?

Suggested change
* 1-D arrays x and y (and z, optionally)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the docstring in b575591. Currently, any unrecognized data type is treated as a matrix type. That's not ideal, so I plan to refactor the data_kind function later and will polish the docstring as part of that refactoring.

* an optional argument (None, bool, int or float) provided as 'data'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* an optional argument (None, bool, int or float) provided as 'data'
* an optional argument (None, bool, int or float)


Arguments should be ``None`` if not used. If doesn't fit any of these
categories (or fits more than one), will raise an exception.

Parameters
----------
data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like}
Pass in either a file name or :class:`pathlib.Path` to an ASCII data
table, an :class:`xarray.DataArray`, a 1-D/2-D
{table-classes} or an option argument.
x/y : 1-D arrays or None
x and y columns as numpy arrays.
z : 1-D array or None
z column as numpy array. To be used optionally when x and y are given.
required_z : bool
State whether the 'z' column is required.
required_data : bool
required
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].

Returns
-------
kind : str
One of ``'arg'``, ``'file'``, ``'grid'``, ``image``, ``'geojson'``,
``'matrix'``, or ``'vectors'``.
kind
The data kind.

Examples
--------

>>> import numpy as np
>>> import xarray as xr
>>> import pathlib
>>> data_kind(data=None, x=np.array([1, 2, 3]), y=np.array([4, 5, 6]))
>>> data_kind(data=None)
'vectors'
>>> data_kind(data=np.arange(10).reshape((5, 2)), x=None, y=None)
>>> data_kind(data=np.arange(10).reshape((5, 2)))
'matrix'
>>> data_kind(data="my-data-file.txt", x=None, y=None)
>>> data_kind(data="my-data-file.txt")
'file'
>>> data_kind(data=pathlib.Path("my-data-file.txt"), x=None, y=None)
>>> data_kind(data=pathlib.Path("my-data-file.txt"))
'file'
>>> data_kind(data=None, x=None, y=None, required_data=False)
>>> data_kind(data=None, required=False)
'arg'
>>> data_kind(data=2.0, x=None, y=None, required_data=False)
>>> data_kind(data=2.0, required=False)
'arg'
>>> data_kind(data=True, x=None, y=None, required_data=False)
>>> data_kind(data=True, required=False)
'arg'
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3)))
'grid'
Expand All @@ -181,7 +176,7 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data
):
# One or more files
kind = "file"
elif isinstance(data, bool | int | float) or (data is None and not required_data):
elif isinstance(data, bool | int | float) or (data is None and not required):
kind = "arg"
elif isinstance(data, xr.DataArray):
kind = "image" if len(data.dims) == 3 else "grid"
Expand All @@ -193,15 +188,6 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data
kind = "matrix"
else:
kind = "vectors"
_validate_data_input(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed from here and moved to Session.virtualfile_in.

data=data,
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
kind=kind,
)
return kind


Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def plot( # noqa: PLR0912
"""
kwargs = self._preprocess(**kwargs)

kind = data_kind(data, x, y)
kind = data_kind(data)
extra_arrays = []
if kind == "vectors": # Add more columns for vectors input
# Parameters for vector styles
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def plot3d( # noqa: PLR0912
"""
kwargs = self._preprocess(**kwargs)

kind = data_kind(data, x, y, z)
kind = data_kind(data)
extra_arrays = []

if kind == "vectors": # Add more columns for vectors input
Expand Down
4 changes: 2 additions & 2 deletions pygmt/src/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,11 @@ def text_( # noqa: PLR0912

# Ensure inputs are either textfiles, x/y/text, or position/text
if position is None:
if (x is not None or y is not None) and textfiles is not None:
if any(v is not None for v in (x, y, text)) and textfiles is not None:
raise GMTInvalidInput(
"Provide either position only, or x/y pairs, or textfiles."
)
kind = data_kind(textfiles, x, y, text)
kind = data_kind(textfiles)
if kind == "vectors" and text is None:
raise GMTInvalidInput("Must provide text with x/y pairs")
else:
Expand Down
28 changes: 1 addition & 27 deletions pygmt/tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,11 @@

from pathlib import Path

import numpy as np
import pytest
import xarray as xr
from pygmt import Figure
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
args_in_kwargs,
data_kind,
kwargs_to_strings,
unique_name,
)
from pygmt.helpers import GMTTempFile, args_in_kwargs, kwargs_to_strings, unique_name
from pygmt.helpers.testing import load_static_earth_relief, skip_if_no


Expand All @@ -32,25 +25,6 @@ def test_load_static_earth_relief():
assert isinstance(data, xr.DataArray)


@pytest.mark.parametrize(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is removed because:

  1. data_kind no longer checks if the input is valid
  2. All these tests are already covered by the validate_data_input doctests.

("data", "x", "y"),
[
(None, None, None),
("data.txt", np.array([1, 2]), np.array([4, 5])),
("data.txt", np.array([1, 2]), None),
("data.txt", None, np.array([4, 5])),
(None, np.array([1, 2]), None),
(None, None, np.array([4, 5])),
],
)
def test_data_kind_fails(data, x, y):
"""
Make sure data_kind raises exceptions when it should.
"""
with pytest.raises(GMTInvalidInput):
data_kind(data=data, x=x, y=y)


def test_unique_name():
"""
Make sure the names are really unique.
Expand Down
Loading