Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Session.virtualfile_in, removing 'extra_arrays'/'required_z' and add 'required_cols' #3369

Draft
wants to merge 4 commits into
base: refactor/data_kind
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,9 +1605,8 @@
x=None,
y=None,
z=None,
extra_arrays=None,
required_z=False,
required_data=True,
required_cols: int = 2,
):
"""
Store any data inside a virtual file.
Expand All @@ -1627,14 +1626,11 @@
data input.
x/y/z : 1-D arrays or None
x, y, and z columns as numpy arrays.
extra_arrays : list of 1-D arrays
Optional. A list of numpy arrays in addition to x, y, and z.
All of these arrays must be of the same size as the x/y/z arrays.
required_z : bool
State whether the 'z' column is required.
required_data : bool
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].
required_cols
Number of required columns.

Returns
-------
Expand Down Expand Up @@ -1668,8 +1664,8 @@
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
required_cols=required_cols,
kind=kind,
)

Expand Down Expand Up @@ -1723,8 +1719,6 @@
_data = [np.atleast_1d(x), np.atleast_1d(y)]
if z is not None:
_data.append(np.atleast_1d(z))
if extra_arrays:
_data.extend(extra_arrays)
case "vectors":
if hasattr(data, "items") and not hasattr(data, "to_frame"):
# Dict, pandas.DataFrame or xarray.Dataset types.
Expand Down Expand Up @@ -1757,20 +1751,32 @@
instead.
"""
msg = (
"API function 'Session.virtualfile_from_datae()' has been deprecated since "
"API function 'Session.virtualfile_from_data()' has been deprecated since "
"v0.13.0 and will be removed in v0.15.0. Use 'Session.virtualfile_in()' "
"instead."
)
warnings.warn(msg, category=FutureWarning, stacklevel=2)
# Session.virtualfile_in no longer has the 'extra_arrays' parameter.
if data is None and extra_arrays is not None:
data = [np.atleast_1d(x), np.atleast_1d(y)]
if z is not None:
data.append(np.atleast_1d(z))

Check warning on line 1763 in pygmt/clib/session.py

View check run for this annotation

Codecov / codecov/patch

pygmt/clib/session.py#L1763

Added line #L1763 was not covered by tests
data.extend(extra_arrays)
x, y, z = None, None, None

# Need to convert the list of arrays into a pandas.DataFrame object.
# Otherwise, the "vector" `data` will be converted to a homogeneous 2D
# numpy.ndarray first.
data = pd.concat(objs=[pd.Series(array) for array in data], axis="columns")

return self.virtualfile_in(
check_kind=check_kind,
data=data,
x=x,
y=y,
z=z,
extra_arrays=extra_arrays,
required_z=required_z,
required_data=required_data,
required_cols=3 if required_z else 2,
)

@contextlib.contextmanager
Expand Down
67 changes: 39 additions & 28 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


def _validate_data_input(
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
data=None, x=None, y=None, z=None, required_data=True, required_cols=2, kind=None
):
"""
Check if the combination of data/x/y/z is valid.
Expand All @@ -44,34 +44,33 @@ def _validate_data_input(
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True)
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_cols=3)
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
>>> import numpy as np
>>> import pandas as pd
>>> import xarray as xr
>>> data = np.arange(8).reshape((4, 2))
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
>>> _validate_data_input(data=data, required_cols=3, kind="matrix")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
>>> _validate_data_input(
... data=pd.DataFrame(data, columns=["x", "y"]),
... required_z=True,
... required_cols=3,
... kind="matrix",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
>>> _validate_data_input(
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
... required_z=True,
... kind="matrix",
... required_cols=3,
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
>>> _validate_data_input(data="infile", x=[1, 2, 3])
Traceback (most recent call last):
...
Expand All @@ -94,26 +93,38 @@ def _validate_data_input(
GMTInvalidInput
If the data input is not valid.
"""
if data is None: # data is None
if x is None and y is None: # both x and y are None
if required_data: # data is not optional
if kind is None:
kind = data_kind(data, required=required_data)

if data is not None and any(v is not None for v in (x, y, z)):
raise GMTInvalidInput("Too much data. Use either data or x/y/z.")

match kind:
case "none":
if x is None and y is None: # both x and y are None
raise GMTInvalidInput("No input data provided.")
elif x is None or y is None: # either x or y is None
raise GMTInvalidInput("Must provide both x and y.")
if required_z and z is None: # both x and y are not None, now check z
raise GMTInvalidInput("Must provide x, y, and z.")
else: # data is not None
if x is not None or y is not None or z is not None:
raise GMTInvalidInput("Too much data. Use either data or x/y/z.")
# For 'matrix' kind, check if data has the required z column
if kind == "matrix" and required_z:
if hasattr(data, "shape"): # np.ndarray or pd.DataFrame
if len(data.shape) == 1 and data.shape[0] < 3:
raise GMTInvalidInput("data must provide x, y, and z columns.")
if len(data.shape) > 1 and data.shape[1] < 3:
raise GMTInvalidInput("data must provide x, y, and z columns.")
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
raise GMTInvalidInput("data must provide x, y, and z columns.")
if x is None or y is None: # either x or y is None
raise GMTInvalidInput("Must provide both x and y.")
if required_cols >= 3 and z is None:
# both x and y are not None, now check z
raise GMTInvalidInput("Must provide x, y, and z.")
case "matrix": # 2-D numpy.ndarray
if (actual_cols := data.shape[1]) < required_cols:
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
raise GMTInvalidInput(msg)
case "vectors":
if hasattr(data, "items") and not hasattr(data, "to_frame"):
# Dict, pd.DataFrame, xr.Dataset
arrays = [array for _, array in data.items()]
if (actual_cols := len(arrays)) < required_cols:
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
raise GMTInvalidInput(msg)

# Loop over columns to make sure they're not None
for idx, array in enumerate(arrays[:required_cols]):
if array is None:
msg = f"data needs {required_cols} columns but the {idx} column is None."
raise GMTInvalidInput(msg)


def _check_encoding(
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/blockm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _blockm(
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
):
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def contour(self, data=None, x=None, y=None, z=None, **kwargs):

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl:
lib.call_module(
module="contour", args=build_arg_list(kwargs, infile=vintbl)
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/nearneighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def nearneighbor(
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl,
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
):
Expand Down
19 changes: 9 additions & 10 deletions pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,26 +209,27 @@ def plot( # noqa: PLR0912
kwargs = self._preprocess(**kwargs)

kind = data_kind(data)
extra_arrays = []
if kind == "none": # Add more columns for vectors input
if kind == "none": # Vectors input
data = {"x": x, "y": y}
x, y = None, None
# Parameters for vector styles
if (
kwargs.get("S") is not None
and kwargs["S"][0] in "vV"
and is_nonstr_iter(direction)
):
extra_arrays.extend(direction)
data.update({"x2": direction[0], "y2": direction[1]})
# Fill
if is_nonstr_iter(kwargs.get("G")):
extra_arrays.append(kwargs.get("G"))
data["fill"] = kwargs["G"]
del kwargs["G"]
# Size
if is_nonstr_iter(size):
extra_arrays.append(size)
data["size"] = size
# Intensity and transparency
for flag in ["I", "t"]:
for flag, name in ["I", "intensity"], ["t", "transparency"]:
if is_nonstr_iter(kwargs.get(flag)):
extra_arrays.append(kwargs.get(flag))
data[name] = kwargs[flag]
kwargs[flag] = ""
else:
for name, value in [
Expand All @@ -255,7 +256,5 @@ def plot( # noqa: PLR0912
pass

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, extra_arrays=extra_arrays
) as vintbl:
with lib.virtualfile_in(check_kind="vector", data=data, x=x, y=y) as vintbl:
lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl))
24 changes: 9 additions & 15 deletions pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,27 +184,27 @@ def plot3d( # noqa: PLR0912
kwargs = self._preprocess(**kwargs)

kind = data_kind(data)
extra_arrays = []

if kind == "none": # Add more columns for vectors input
if kind == "none": # Vectors input
data = {"x": x, "y": y, "z": z}
x, y, z = None, None, None
# Parameters for vector styles
if (
kwargs.get("S") is not None
and kwargs["S"][0] in "vV"
and is_nonstr_iter(direction)
):
extra_arrays.extend(direction)
data.update({"x2": direction[0], "y2": direction[1]})
# Fill
if is_nonstr_iter(kwargs.get("G")):
extra_arrays.append(kwargs.get("G"))
data["fill"] = kwargs["G"]
del kwargs["G"]
# Size
if is_nonstr_iter(size):
extra_arrays.append(size)
data["size"] = size
# Intensity and transparency
for flag in ["I", "t"]:
for flag, name in [("I", "intensity"), ("t", "transparency")]:
if is_nonstr_iter(kwargs.get(flag)):
extra_arrays.append(kwargs.get(flag))
data[name] = kwargs[flag]
kwargs[flag] = ""
else:
for name, value in [
Expand Down Expand Up @@ -232,12 +232,6 @@ def plot3d( # noqa: PLR0912

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector",
data=data,
x=x,
y=y,
z=z,
extra_arrays=extra_arrays,
required_z=True,
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl:
lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl))
2 changes: 1 addition & 1 deletion pygmt/src/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def project(
x=x,
y=y,
z=z,
required_z=False,
required_cols=2,
required_data=False,
) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def surface(data=None, x=None, y=None, z=None, outgrid: str | None = None, **kwa
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl,
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
):
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def text_( # noqa: PLR0912
confdict = {"PS_CHAR_ENCODING": encoding}

with Session() as lib:
with lib.virtualfile_in(
with lib.virtualfile_from_data(
check_kind="vector", data=textfiles, x=x, y=y, extra_arrays=extra_arrays
) as vintbl:
lib.call_module(
Expand Down
4 changes: 2 additions & 2 deletions pygmt/src/triangulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def regular_grid(
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=2
) as vintbl,
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
):
Expand Down Expand Up @@ -238,7 +238,7 @@ def delaunay_triples(
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=2
) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
):
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/wiggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,6 @@ def wiggle(

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl:
lib.call_module(module="wiggle", args=build_arg_list(kwargs, infile=vintbl))
2 changes: 1 addition & 1 deletion pygmt/src/xyz2grd.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def xyz2grd(data=None, x=None, y=None, z=None, outgrid: str | None = None, **kwa
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl,
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
):
Expand Down
Loading
Loading