Skip to content

Commit

Permalink
Change ncols to names
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Oct 15, 2023
1 parent f849e5a commit f37413b
Show file tree
Hide file tree
Showing 15 changed files with 75 additions and 53 deletions.
9 changes: 5 additions & 4 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,7 +1476,7 @@ def virtualfile_from_data(
check_kind=None,
data=None,
vectors=None,
ncols=2,
names=["x", "y"],
required_data=True,
):
"""
Expand All @@ -1498,8 +1498,9 @@ def virtualfile_from_data(
vectors : list of 1-D arrays or None
A list of 1-D arrays. Each array will be a column in the table.
All of these arrays must be of the same size.
ncols : int
The minimum number of columns required for the data.
names : list of str
A list of names for each of the columns. Must be of the same size
as the number of vectors. Default is ``["x", "y"]``.
required_data : bool
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].
Expand Down Expand Up @@ -1537,7 +1538,7 @@ def virtualfile_from_data(
validate_data_input(
data=data,
vectors=vectors,
ncols=ncols,
names=names,
required_data=required_data,
kind=kind,
)
Expand Down
64 changes: 41 additions & 23 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,52 +16,71 @@


def validate_data_input(
data=None, vectors=None, ncols=2, required_data=True, kind=None
data=None, vectors=None, names=["x", "y"], required_data=True, kind=None
):
"""
Check if the data input is valid.
Parameters
----------
data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like}
Pass in either a file name or :class:`pathlib.Path` to an ASCII data
table, an :class:`xarray.DataArray`, a 1-D/2-D
{table-classes} or an option argument.
vectors : list of 1-D arrays
A list of 1-D arrays with the data columns.
names : list of str
List of column names.
required_data : bool
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].
kind : str or None
The kind of data that will be passed to a module. If not given, it
will be determined by calling :func:`data_kind`.
Examples
--------
>>> validate_data_input(data="infile")
>>> validate_data_input(vectors=[[1, 2, 3], [4, 5, 6]], ncols=2)
>>> validate_data_input(vectors=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], ncols=3)
>>> validate_data_input(vectors=[[1, 2, 3], [4, 5, 6]], names="xy")
>>> validate_data_input(
... vectors=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], names="xyz"
... )
>>> validate_data_input(data=None, required_data=False)
>>> validate_data_input()
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: No input data provided.
>>> validate_data_input(vectors=[[1, 2, 3], None], ncols=2)
>>> validate_data_input(vectors=[[1, 2, 3], None], names="xy")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: The 'y' column can't be None.
>>> validate_data_input(vectors=[None, [4, 5, 6]], ncols=2)
>>> validate_data_input(vectors=[None, [4, 5, 6]], names="xy")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: The 'x' column can't be None.
>>> validate_data_input(vectors=[[1, 2, 3], [4, 5, 6], None], ncols=3)
>>> validate_data_input(vectors=[[1, 2, 3], [4, 5, 6], None], names="xyz")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: The 'z' column can't be None.
>>> import numpy as np
>>> import pandas as pd
>>> import xarray as xr
>>> data = np.arange(8).reshape((4, 2))
>>> validate_data_input(data=data, ncols=3, kind="matrix")
>>> validate_data_input(data=data, names="xyz", kind="matrix")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must have at least 3 columns.
>>> validate_data_input(
... data=pd.DataFrame(data, columns=["x", "y"]),
... ncols=3,
... names="xyz",
... kind="matrix",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must have at least 3 columns.
>>> validate_data_input(
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
... ncols=3,
... names="xyz",
... kind="matrix",
... )
Traceback (most recent call last):
Expand Down Expand Up @@ -91,28 +110,27 @@ def validate_data_input(
if kind == "vectors": # From data_kind, we know that data is None
if vectors is None:
raise GMTInvalidInput("No input data provided.")
if len(vectors) < ncols:
if len(vectors) < len(names):
raise GMTInvalidInput(
f"Requires {ncols} 1-D arrays but got {len(vectors)}."
f"Requires {len(names)} 1-D arrays but got {len(vectors)}."
)
for i, v in enumerate(vectors[:ncols]):
for i, v in enumerate(vectors[: len(names)]):
if v is None:
if i < 3:
msg = f"The '{'xyz'[i]}' column can't be None."
else:
msg = "Column {i} can't be None."
raise GMTInvalidInput(msg)
raise GMTInvalidInput(f"Column {i} ('{names[i]}') can't be None.")
else:
if vectors is not None and any(v is not None for v in vectors):
raise GMTInvalidInput("Too much data. Pass in either 'data' or 1-D arrays.")
if kind == "matrix": # check number of columns for matrix-like data
msg = f"data must have at least {len(names)} columns.\n" + " ".join(names)
if hasattr(data, "shape"): # np.ndarray or pd.DataFrame
if len(data.shape) == 1 and data.shape[0] < ncols:
raise GMTInvalidInput(f"data must have at least {ncols} columns.")
if len(data.shape) > 1 and data.shape[1] < ncols:
raise GMTInvalidInput(f"data must have at least {ncols} columns.")
if hasattr(data, "data_vars") and len(data.data_vars) < ncols: # xr.Dataset
raise GMTInvalidInput(f"data must have at least {ncols} columns.")
if len(data.shape) == 1 and data.shape[0] < len(names):
raise GMTInvalidInput(msg)
if len(data.shape) > 1 and data.shape[1] < len(names):
raise GMTInvalidInput(msg)
if hasattr(data, "data_vars") and len(data.data_vars) < len(
names
): # xr.Dataset
raise GMTInvalidInput(msg)


def data_kind(data=None, required=True):
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/blockm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _blockm(block_method, data, x, y, z, outfile, **kwargs):
with GMTTempFile(suffix=".csv") as tmpfile:
with Session() as lib:
table_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=[x, y, z], ncols=3
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
)
# Run blockm* on data table
with table_context as infile:
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def contour(self, data=None, x=None, y=None, z=None, **kwargs):

with Session() as lib:
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=[x, y, z], ncols=3
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
)
with file_context as fname:
lib.call_module(
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/nearneighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def nearneighbor(data=None, x=None, y=None, z=None, **kwargs):
with GMTTempFile(suffix=".nc") as tmpfile:
with Session() as lib:
table_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=[x, y, z], ncols=3
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
)
with table_context as infile:
if (outgrid := kwargs.get("G")) is None:
Expand Down
12 changes: 6 additions & 6 deletions pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):

kind = data_kind(data)
vectors = [x, y]
ncols = 2
names = ["x", "y"]

if kwargs.get("S") is not None and kwargs["S"][0] in "vV" and direction is not None:
vectors.extend(direction)
ncols += 2
names.extend(["x2", "y2"])
elif (
kwargs.get("S") is None
and kind == "geojson"
Expand All @@ -242,15 +242,15 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
"Can't use arrays for fill if data is matrix or file."
)
vectors.append(kwargs["G"])
ncols += 1
names.append("fill")
del kwargs["G"]
if size is not None:
if kind != "vectors":
raise GMTInvalidInput(
"Can't use arrays for 'size' if data is a matrix or file."
)
vectors.append(size)
ncols += 1
names.append("size")

for flag in ["I", "t"]:
if kwargs.get(flag) is not None and is_nonstr_iter(kwargs[flag]):
Expand All @@ -259,12 +259,12 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
f"Can't use arrays for {plot.aliases[flag]} if data is matrix or file."
)
vectors.append(kwargs[flag])
ncols += 1
names.append(plot.aliases[flag])
kwargs[flag] = ""

with Session() as lib:
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=vectors, ncols=ncols
check_kind="vector", data=data, vectors=vectors, names=names
)

with file_context as fname:
Expand Down
12 changes: 6 additions & 6 deletions pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ def plot3d(

kind = data_kind(data)
vectors = [x, y, z]
ncols = 3
names = ["x", "y", "z"]

if kwargs.get("S") is not None and kwargs["S"][0] in "vV" and direction is not None:
vectors.extend(direction)
ncols += 2
names.extend(["x2", "y2"])
elif (
kwargs.get("S") is None
and kind == "geojson"
Expand All @@ -212,15 +212,15 @@ def plot3d(
"Can't use arrays for fill if data is matrix or file."
)
vectors.append(kwargs["G"])
ncols += 1
names.append("fill")
del kwargs["G"]
if size is not None:
if kind != "vectors":
raise GMTInvalidInput(
"Can't use arrays for 'size' if data is a matrix or a file."
)
ncols += 1
vectors.append(size)
names.append("size")

for flag in ["I", "t"]:
if kwargs.get(flag) is not None and is_nonstr_iter(kwargs[flag]):
Expand All @@ -229,12 +229,12 @@ def plot3d(
f"Can't use arrays for {plot3d.aliases[flag]} if data is matrix or file."
)
vectors.append(kwargs[flag])
ncols += 1
names.append(plot3d.aliases[flag])
kwargs[flag] = ""

with Session() as lib:
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=vectors, ncols=ncols
check_kind="vector", data=data, vectors=vectors, names=names
)

with file_context as fname:
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
with Session() as lib:
if kwargs.get("G") is None:
table_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=[x, y, z], ncols=3
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
)

# Run project on the temporary (csv) data table
Expand Down
5 changes: 4 additions & 1 deletion pygmt/src/rose.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,10 @@ def rose(self, data=None, length=None, azimuth=None, **kwargs):

with Session() as lib:
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=[length, azimuth], ncols=2
check_kind="vector",
data=data,
vectors=[length, azimuth],
names=["length", "azimuth"],
)

with file_context as fname:
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/sphdistance.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def sphdistance(data=None, x=None, y=None, **kwargs):
with GMTTempFile(suffix=".nc") as tmpfile:
with Session() as lib:
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=[x, y], ncols=2
check_kind="vector", data=data, vectors=[x, y], names="xy"
)
with file_context as infile:
if (outgrid := kwargs.get("G")) is None:
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def surface(data=None, x=None, y=None, z=None, **kwargs):
with GMTTempFile(suffix=".nc") as tmpfile:
with Session() as lib:
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=[x, y, z], ncols=3
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
)
with file_context as infile:
if (outgrid := kwargs.get("G")) is None:
Expand Down
8 changes: 4 additions & 4 deletions pygmt/src/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,24 +222,24 @@ def text_(
kwargs["F"] += f"+c{position}+t{text}"

vectors = [x, y]
ncols = 2
names = ["x", "y"]
# If an array of transparency is given, GMT will read it from
# the last numerical column per data record.
if kwargs.get("t") is not None and is_nonstr_iter(kwargs["t"]):
vectors.append(kwargs["t"])
kwargs["t"] = ""
ncols += 1
names.append("transparency")

# Append text at last column. Text must be passed in as str type.
if kind == "vectors":
vectors.append(
np.vectorize(non_ascii_to_octal)(np.atleast_1d(text).astype(str))
)
ncols += 1
names.append("text")

with Session() as lib:
file_context = lib.virtualfile_from_data(
check_kind="vector", data=textfiles, vectors=vectors, ncols=ncols
check_kind="vector", data=textfiles, vectors=vectors, names=names
)
with file_context as fname:
lib.call_module(module="text", args=build_arg_string(kwargs, infile=fname))
2 changes: 1 addition & 1 deletion pygmt/src/triangulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _triangulate(
"""
with Session() as lib:
table_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=[x, y, z], ncols=3
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
)
with table_context as infile:
# table output if outgrid is unset, else output to outgrid
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/wiggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def wiggle(

with Session() as lib:
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=[x, y, z], ncols=3
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
)

with file_context as fname:
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/xyz2grd.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def xyz2grd(data=None, x=None, y=None, z=None, **kwargs):
with GMTTempFile(suffix=".nc") as tmpfile:
with Session() as lib:
file_context = lib.virtualfile_from_data(
check_kind="vector", data=data, vectors=[x, y, z], ncols=3
check_kind="vector", data=data, vectors=[x, y, z], names="xyz"
)
with file_context as infile:
if (outgrid := kwargs.get("G")) is None:
Expand Down

0 comments on commit f37413b

Please sign in to comment.