Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the virtualfile_in function to accept more 1-D arrays #2744

Draft
wants to merge 30 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
66c4b97
Refactor the data_kind and the virtualfile_to_data functions
seisman Oct 13, 2023
78c28cd
Update more functions
seisman Oct 14, 2023
f849e5a
Merge branch 'main' into refactor/virtualfile-to-data
seisman Oct 15, 2023
f37413b
Change ncols to names
seisman Oct 15, 2023
3de7666
Fix more tests
seisman Oct 15, 2023
93b91d0
Fix project
seisman Oct 15, 2023
2eecf48
Merge branch 'main' into refactor/virtualfile-to-data
seisman Oct 16, 2023
1d6e568
Fix more tests
seisman Oct 16, 2023
6f9fc19
Fixes
seisman Oct 16, 2023
68034ed
Merge branch 'main' into refactor/virtualfile-to-data
seisman Oct 17, 2023
0db21bc
Fix triangulate
seisman Oct 17, 2023
7cf5290
Fix text
seisman Oct 17, 2023
b0b6d2a
Fix more failing tests
seisman Oct 17, 2023
fa875ef
More fixes
seisman Oct 17, 2023
2ee0df2
Fix linting issues
seisman Oct 17, 2023
d5c8340
Fix linting issues
seisman Oct 17, 2023
30bacb1
Fix linting issues
seisman Oct 18, 2023
4465f9b
Merge branch 'main' into refactor/virtualfile-to-data
seisman Oct 20, 2023
593f252
Update pygmt/clib/session.py
seisman Oct 20, 2023
409337f
Apply suggestions from code review
seisman Oct 25, 2023
872fd59
Merge branch 'main' into refactor/virtualfile-to-data
seisman Dec 25, 2023
3ed0eb2
Merge branch 'main' into refactor/virtualfile-to-data
seisman Jan 16, 2024
efa7a11
Merge branch 'main' into refactor/virtualfile-to-data
seisman Jan 18, 2024
23fc3ea
Merge branch 'main' into refactor/virtualfile-to-data
seisman Mar 1, 2024
aa05333
Merge branch 'main' into refactor/virtualfile-to-data
seisman Jul 11, 2024
5c10fc4
Fix plot and plot3d
seisman Jul 11, 2024
525a353
Fix errors in merging the main branch
seisman Jul 11, 2024
2f3fcc4
Merge branch 'main' into refactor/virtualfile-to-data
seisman Jul 20, 2024
b55a9ad
Fix merging issue
seisman Jul 20, 2024
46be0fa
Merge branch 'main' into refactor/virtualfile-to-data
seisman Jul 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 12 additions & 22 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,15 +1621,12 @@ def virtualfile_from_grid(self, grid):
with self.open_virtualfile(*args) as vfile:
yield vfile

def virtualfile_in( # noqa: PLR0912
def virtualfile_in(
self,
check_kind=None,
data=None,
x=None,
y=None,
z=None,
extra_arrays=None,
required_z=False,
vectors=None,
names="xy",
required_data=True,
):
"""
Expand All @@ -1648,13 +1645,12 @@ def virtualfile_in( # noqa: PLR0912
Any raster or vector data format. This could be a file name or
path, a raster grid, a vector matrix/arrays, or other supported
data input.
x/y/z : 1-D arrays or None
x, y, and z columns as numpy arrays.
extra_arrays : list of 1-D arrays
Optional. A list of numpy arrays in addition to x, y, and z.
All of these arrays must be of the same size as the x/y/z arrays.
required_z : bool
State whether the 'z' column is required.
vectors : list of 1-D arrays or None
A list of 1-D arrays. Each array will be a column in the table.
All of these arrays must be of the same size.
names : str or list of str
A list of names for each of the columns. Must be of the same size
as the number of vectors. Default is ``"xy"``.
required_data : bool
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].
Expand Down Expand Up @@ -1688,10 +1684,8 @@ def virtualfile_in( # noqa: PLR0912
kind = data_kind(data, required=required_data)
_validate_data_input(
data=data,
x=x,
y=y,
z=z,
required_z=required_z,
vectors=vectors,
names=names,
required_data=required_data,
kind=kind,
)
Expand Down Expand Up @@ -1734,11 +1728,7 @@ def virtualfile_in( # noqa: PLR0912
warnings.warn(message=msg, category=RuntimeWarning, stacklevel=2)
_data = (data,) if not isinstance(data, pathlib.PurePath) else (str(data),)
elif kind == "vectors":
_data = [np.atleast_1d(x), np.atleast_1d(y)]
if z is not None:
_data.append(np.atleast_1d(z))
if extra_arrays:
_data.extend(extra_arrays)
_data = [np.atleast_1d(v) for v in vectors]
elif kind == "matrix": # turn 2-D arrays into list of vectors
if hasattr(data, "items") and not hasattr(data, "to_frame"):
# pandas.DataFrame or xarray.Dataset types.
Expand Down
132 changes: 79 additions & 53 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,99 +20,126 @@


def _validate_data_input(
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
data=None, vectors=None, names="xy", required_data=True, kind=None
):
"""
Check if the combination of data/x/y/z is valid.
Check if the data input is valid.

Parameters
----------
data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like}
Pass in either a file name or :class:`pathlib.Path` to an ASCII data
table, an :class:`xarray.DataArray`, a 1-D/2-D
{table-classes} or an option argument.
vectors : list of 1-D arrays
A list of 1-D arrays with the data columns.
names : list of str
List of column names.
required_data : bool
Set to True when 'data' is required, or False when dealing with
optional virtual files [Default is True].
kind : str or None
The kind of data that will be passed to a module. If not given, it
will be determined by calling :func:`data_kind`.

Examples
--------
>>> _validate_data_input(data="infile")
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6])
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9])
>>> _validate_data_input(vectors=[[1, 2, 3], [4, 5, 6]], names="xy")
>>> _validate_data_input(vectors=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], names="xyz")
>>> _validate_data_input(data=None, required_data=False)
>>> _validate_data_input()
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: No input data provided.
>>> _validate_data_input(x=[1, 2, 3])
>>> _validate_data_input(vectors=[[1, 2, 3], None], names="xy")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(y=[4, 5, 6])
pygmt.exceptions.GMTInvalidInput: Column 1 ('y') can't be None.
>>> _validate_data_input(vectors=[None, [4, 5, 6]], names="xy")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True)
pygmt.exceptions.GMTInvalidInput: Column 0 ('x') can't be None.
>>> _validate_data_input(vectors=[[1, 2, 3], [4, 5, 6], None], names="xyz")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
pygmt.exceptions.GMTInvalidInput: Column 2 ('z') can't be None.
>>> import numpy as np
>>> import pandas as pd
>>> import xarray as xr
>>> data = np.arange(8).reshape((4, 2))
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
>>> _validate_data_input(data=data, names="xyz", kind="matrix")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
pygmt.exceptions.GMTInvalidInput: data must have at least 3 columns.
x y z
>>> _validate_data_input(
... data=pd.DataFrame(data, columns=["x", "y"]),
... required_z=True,
... names="xyz",
... kind="matrix",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
pygmt.exceptions.GMTInvalidInput: data must have at least 3 columns.
x y z
>>> _validate_data_input(
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
... required_z=True,
... names="xyz",
... kind="matrix",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(data="infile", x=[1, 2, 3])
pygmt.exceptions.GMTInvalidInput: data must have at least 3 columns.
x y z
>>> _validate_data_input(data="infile", vectors=[[1, 2, 3], None])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", y=[4, 5, 6])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6])
pygmt...GMTInvalidInput: Too much data. Use either 'data' or 1-D arrays.
>>> _validate_data_input(data="infile", vectors=[None, None, [7, 8, 9]])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
>>> _validate_data_input(data="infile", z=[7, 8, 9])
pygmt...GMTInvalidInput: Too much data. Use either 'data' or 1-D arrays.
>>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6])
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
pygmt...GMTInvalidInput: Too much data. Use either 'data' or 1-D arrays.

Raises
------
GMTInvalidInput
If the data input is not valid.
"""
if data is None: # data is None
if x is None and y is None: # both x and y are None
if required_data: # data is not optional
raise GMTInvalidInput("No input data provided.")
elif x is None or y is None: # either x or y is None
raise GMTInvalidInput("Must provide both x and y.")
if required_z and z is None: # both x and y are not None, now check z
raise GMTInvalidInput("Must provide x, y, and z.")
else: # data is not None
if x is not None or y is not None or z is not None:
raise GMTInvalidInput("Too much data. Use either data or x/y/z.")
# For 'matrix' kind, check if data has the required z column
if kind == "matrix" and required_z:
if kind is None:
kind = data_kind(data=data, required=required_data)

if kind == "vectors": # From data_kind, we know that data is None
if vectors is None:
raise GMTInvalidInput("No input data provided.")
if len(vectors) < len(names):
raise GMTInvalidInput(
f"Requires {len(names)} 1-D arrays but got {len(vectors)}."
)
Comment on lines +122 to +125
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing unit test for this if-condition.

for i, v in enumerate(vectors[: len(names)]):
if v is None:
raise GMTInvalidInput(f"Column {i} ('{names[i]}') can't be None.")
else:
if vectors is not None and any(v is not None for v in vectors):
raise GMTInvalidInput("Too much data. Use either 'data' or 1-D arrays.")
if kind == "matrix": # check number of columns for matrix-like data
msg = f"data must have at least {len(names)} columns.\n" + " ".join(names)
if hasattr(data, "shape"): # np.ndarray or pd.DataFrame
if len(data.shape) == 1 and data.shape[0] < 3:
raise GMTInvalidInput("data must provide x, y, and z columns.")
if len(data.shape) > 1 and data.shape[1] < 3:
raise GMTInvalidInput("data must provide x, y, and z columns.")
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
raise GMTInvalidInput("data must provide x, y, and z columns.")
if len(data.shape) == 1 and data.shape[0] < len(names):
raise GMTInvalidInput(msg)

Check warning on line 136 in pygmt/helpers/utils.py

View check run for this annotation

Codecov / codecov/patch

pygmt/helpers/utils.py#L136

Added line #L136 was not covered by tests
Comment on lines +135 to +136
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing unit test for this if-condition.

if len(data.shape) > 1 and data.shape[1] < len(names):
raise GMTInvalidInput(msg)
if hasattr(data, "data_vars") and len(data.data_vars) < len(
names
): # xr.Dataset
raise GMTInvalidInput(msg)


def _check_encoding(
Expand Down Expand Up @@ -189,19 +216,18 @@

def data_kind(
data: Any = None, required: bool = True
) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
) -> Literal["arg", "file", "grid", "image", "matrix", "vectors"]:
"""
Check the kind of data that is provided to a module.
Check what kind of data is provided to a module.

The ``data`` argument can be in any type, but only following types are supported:
Possible types:

- a string or a :class:`pathlib.PurePath` object or a sequence of them, representing
a file name or a list of file names
- a 2-D or 3-D :class:`xarray.DataArray` object
- a 2-D matrix
- None, bool, int or float type representing an optional argument
- a geo-like Python object that implements ``__geo_interface__`` (e.g.,
geopandas.GeoDataFrame or shapely.geometry)
* a file name provided as 'data'
* a pathlib.PurePath object provided as 'data'
* an xarray.DataArray object provided as 'data'
* a 2-D matrix provided as 'data'
* 1-D arrays x and y (and z, optionally)
* an optional argument (None, bool, int or float) provided as 'data'

Parameters
----------
Expand Down Expand Up @@ -257,9 +283,9 @@
# geo-like Python object that implements ``__geo_interface__``
# (geopandas.GeoDataFrame or shapely.geometry)
kind = "geojson"
elif data is not None:
elif data is not None: # anything but None is taken as a matrix
kind = "matrix"
else:
else: # fallback to vectors if data is None but required
kind = "vectors"
return kind

Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/blockm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _blockm(
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
):
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def contour(self, data=None, x=None, y=None, z=None, **kwargs):

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
) as vintbl:
lib.call_module(
module="contour", args=build_arg_list(kwargs, infile=vintbl)
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/nearneighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def nearneighbor(
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
) as vintbl,
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
):
Expand Down
33 changes: 19 additions & 14 deletions pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
kwargs_to_strings,
use_alias,
)
from pygmt.src.which import which
from pygmt.src import which


@fmt_docstring
Expand Down Expand Up @@ -50,9 +50,7 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", c="sequence_comma", i="sequence_comma", p="sequence")
def plot( # noqa: PLR0912
self, data=None, x=None, y=None, size=None, direction=None, **kwargs
):
def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
r"""
Plot lines, polygons, and symbols in 2-D.

Expand Down Expand Up @@ -209,26 +207,32 @@ def plot( # noqa: PLR0912
kwargs = self._preprocess(**kwargs)

kind = data_kind(data)
extra_arrays = []
vectors = [x, y]
names = ["x", "y"]

if kind == "vectors": # Add more columns for vectors input
# Parameters for vector styles
if (
kwargs.get("S") is not None
and kwargs["S"][0] in "vV"
and is_nonstr_iter(direction)
):
extra_arrays.extend(direction)
vectors.extend(direction)
names.extend(["x2", "y2"])
# Fill
if is_nonstr_iter(kwargs.get("G")):
extra_arrays.append(kwargs.get("G"))
vectors.append(kwargs["G"])
names.append("fill")
del kwargs["G"]
# Size
if is_nonstr_iter(size):
extra_arrays.append(size)
vectors.append(size)
names.append("size")
# Intensity and transparency
for flag in ["I", "t"]:
if is_nonstr_iter(kwargs.get(flag)):
extra_arrays.append(kwargs.get(flag))
vectors.append(kwargs[flag])
names.append(plot.aliases[flag])
kwargs[flag] = ""
else:
for name, value in [
Expand All @@ -240,7 +244,6 @@ def plot( # noqa: PLR0912
]:
if is_nonstr_iter(value):
raise GMTInvalidInput(f"'{name}' can't be 1-D array if 'data' is used.")

# Set the default style if data has a geometry of Point or MultiPoint
if kwargs.get("S") is None:
if kind == "geojson" and data.geom_type.isin(["Point", "MultiPoint"]).all():
Expand All @@ -255,7 +258,9 @@ def plot( # noqa: PLR0912
pass

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, extra_arrays=extra_arrays
) as vintbl:
lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl))
file_context = lib.virtualfile_in(
check_kind="vector", data=data, vectors=vectors, names=names
)

with file_context as fname:
lib.call_module(module="plot", args=build_arg_list(kwargs, infile=fname))
Loading
Loading