From bcd1e404dd8e97a884f40ce2c6190aac6b20ef50 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 24 Jul 2024 15:04:02 +0800 Subject: [PATCH 1/4] Session.virtualfile_in: Remove the extra_arrays parameter --- pygmt/clib/session.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 7e74219937f..fa7e4b5691b 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1605,7 +1605,6 @@ def virtualfile_in( # noqa: PLR0912 x=None, y=None, z=None, - extra_arrays=None, required_z=False, required_data=True, ): @@ -1627,9 +1626,6 @@ def virtualfile_in( # noqa: PLR0912 data input. x/y/z : 1-D arrays or None x, y, and z columns as numpy arrays. - extra_arrays : list of 1-D arrays - Optional. A list of numpy arrays in addition to x, y, and z. - All of these arrays must be of the same size as the x/y/z arrays. required_z : bool State whether the 'z' column is required. required_data : bool @@ -1723,8 +1719,6 @@ def virtualfile_in( # noqa: PLR0912 _data = [np.atleast_1d(x), np.atleast_1d(y)] if z is not None: _data.append(np.atleast_1d(z)) - if extra_arrays: - _data.extend(extra_arrays) case "vectors": if hasattr(data, "items") and not hasattr(data, "to_frame"): # Dict, pandas.DataFrame or xarray.Dataset types. @@ -1757,18 +1751,30 @@ def virtualfile_from_data( instead. """ msg = ( - "API function 'Session.virtualfile_from_datae()' has been deprecated since " + "API function 'Session.virtualfile_from_data()' has been deprecated since " "v0.13.0 and will be removed in v0.15.0. Use 'Session.virtualfile_in()' " "instead." ) warnings.warn(msg, category=FutureWarning, stacklevel=2) + # Session.virtualfile_in no longer has the 'extra_arrays' parameter. + if data is None and extra_arrays is not None: + data = [np.atleast_1d(x), np.atleast_1d(y)] + if z is not None: + data.append(np.atleast_1d(z)) + data.extend(extra_arrays) + x, y, z = None, None, None + + # Need to convert the list of arrays into a pandas.DataFrame object. + # Otherwise, the "vector" `data` will be converted to a homogeneous 2D + # numpy.ndarray first. + data = pd.concat(objs=[pd.Series(array) for array in data], axis="columns") + return self.virtualfile_in( check_kind=check_kind, data=data, x=x, y=y, z=z, - extra_arrays=extra_arrays, required_z=required_z, required_data=required_data, ) From 4391db2ba9945051009329b8e992cc0451ee0f0c Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 24 Jul 2024 15:06:49 +0800 Subject: [PATCH 2/4] Refactor Figure.plot and Figure.plot3d to avoid using the extra_arrays parameter --- pygmt/src/plot.py | 18 ++++++++---------- pygmt/src/plot3d.py | 23 ++++++++--------------- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/pygmt/src/plot.py b/pygmt/src/plot.py index f313c37822b..76d0314c609 100644 --- a/pygmt/src/plot.py +++ b/pygmt/src/plot.py @@ -209,26 +209,26 @@ def plot( # noqa: PLR0912 kwargs = self._preprocess(**kwargs) kind = data_kind(data) - extra_arrays = [] - if kind == "none": # Add more columns for vectors input + if kind == "none": # Vectors input + data = {"x": x, "y": y} # Parameters for vector styles if ( kwargs.get("S") is not None and kwargs["S"][0] in "vV" and is_nonstr_iter(direction) ): - extra_arrays.extend(direction) + data.update({"x2": direction[0], "y2": direction[1]}) # Fill if is_nonstr_iter(kwargs.get("G")): - extra_arrays.append(kwargs.get("G")) + data["fill"] = kwargs["G"] del kwargs["G"] # Size if is_nonstr_iter(size): - extra_arrays.append(size) + data["size"] = size # Intensity and transparency - for flag in ["I", "t"]: + for flag, name in ["I", "intensity"], ["t", "transparency"]: if is_nonstr_iter(kwargs.get(flag)): - extra_arrays.append(kwargs.get(flag)) + data[name] = kwargs[flag] kwargs[flag] = "" else: for name, value in [ @@ -255,7 +255,5 @@ def plot( # noqa: PLR0912 pass with Session() as lib: - with lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, extra_arrays=extra_arrays - ) as vintbl: + with lib.virtualfile_in(check_kind="vector", data=data) as vintbl: lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index 704a44a4f7f..c0285a4c7cb 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -184,27 +184,26 @@ def plot3d( # noqa: PLR0912 kwargs = self._preprocess(**kwargs) kind = data_kind(data) - extra_arrays = [] - - if kind == "none": # Add more columns for vectors input + if kind == "none": # Vectors input + data = {"x": x, "y": y, "z": z} # Parameters for vector styles if ( kwargs.get("S") is not None and kwargs["S"][0] in "vV" and is_nonstr_iter(direction) ): - extra_arrays.extend(direction) + data.update({"x2": direction[0], "y2": direction[1]}) # Fill if is_nonstr_iter(kwargs.get("G")): - extra_arrays.append(kwargs.get("G")) + data["fill"] = kwargs["G"] del kwargs["G"] # Size if is_nonstr_iter(size): - extra_arrays.append(size) + data["size"] = size # Intensity and transparency - for flag in ["I", "t"]: + for flag, name in [("I", "intensity"), ("t", "transparency")]: if is_nonstr_iter(kwargs.get(flag)): - extra_arrays.append(kwargs.get(flag)) + data[name] = kwargs[flag] kwargs[flag] = "" else: for name, value in [ @@ -232,12 +231,6 @@ def plot3d( # noqa: PLR0912 with Session() as lib: with lib.virtualfile_in( - check_kind="vector", - data=data, - x=x, - y=y, - z=z, - extra_arrays=extra_arrays, - required_z=True, + check_kind="vector", data=data, required_z=True ) as vintbl: lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl)) From 9e78da059d26db87c000f7422fa0c0c2dcae9381 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Sun, 4 Aug 2024 19:01:22 +0800 Subject: [PATCH 3/4] Temporary switch back to virtualfile_from_data in Figure.text --- pygmt/src/text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/src/text.py b/pygmt/src/text.py index 994856b7250..a7c6bc1dd7f 100644 --- a/pygmt/src/text.py +++ b/pygmt/src/text.py @@ -240,7 +240,7 @@ def text_( # noqa: PLR0912 confdict = {"PS_CHAR_ENCODING": encoding} with Session() as lib: - with lib.virtualfile_in( + with lib.virtualfile_from_data( check_kind="vector", data=textfiles, x=x, y=y, extra_arrays=extra_arrays ) as vintbl: lib.call_module( From b33476c6c9b5b9de36b7d97584cf36d8dd3a13cf Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Sun, 4 Aug 2024 20:19:16 +0800 Subject: [PATCH 4/4] Add new parameter 'required_cols' and remove the parameter 'required_z' --- pygmt/clib/session.py | 10 ++-- pygmt/helpers/utils.py | 67 ++++++++++++++++----------- pygmt/src/blockm.py | 2 +- pygmt/src/contour.py | 2 +- pygmt/src/nearneighbor.py | 2 +- pygmt/src/plot.py | 3 +- pygmt/src/plot3d.py | 3 +- pygmt/src/project.py | 2 +- pygmt/src/surface.py | 2 +- pygmt/src/triangulate.py | 4 +- pygmt/src/wiggle.py | 2 +- pygmt/src/xyz2grd.py | 2 +- pygmt/tests/test_clib_virtualfiles.py | 6 +-- 13 files changed, 60 insertions(+), 47 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index fa7e4b5691b..a25eda2b490 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1605,8 +1605,8 @@ def virtualfile_in( # noqa: PLR0912 x=None, y=None, z=None, - required_z=False, required_data=True, + required_cols: int = 2, ): """ Store any data inside a virtual file. @@ -1626,11 +1626,11 @@ def virtualfile_in( # noqa: PLR0912 data input. x/y/z : 1-D arrays or None x, y, and z columns as numpy arrays. - required_z : bool - State whether the 'z' column is required. required_data : bool Set to True when 'data' is required, or False when dealing with optional virtual files. [Default is True]. + required_cols + Number of required columns. Returns ------- @@ -1664,8 +1664,8 @@ def virtualfile_in( # noqa: PLR0912 x=x, y=y, z=z, - required_z=required_z, required_data=required_data, + required_cols=required_cols, kind=kind, ) @@ -1775,8 +1775,8 @@ def virtualfile_from_data( x=x, y=y, z=z, - required_z=required_z, required_data=required_data, + required_cols=3 if required_z else 2, ) @contextlib.contextmanager diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index e9d819266b8..f6be18bfda3 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -21,7 +21,7 @@ def _validate_data_input( - data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None + data=None, x=None, y=None, z=None, required_data=True, required_cols=2, kind=None ): """ Check if the combination of data/x/y/z is valid. @@ -44,7 +44,7 @@ def _validate_data_input( Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide both x and y. - >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True) + >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_cols=3) Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z. @@ -52,26 +52,25 @@ def _validate_data_input( >>> import pandas as pd >>> import xarray as xr >>> data = np.arange(8).reshape((4, 2)) - >>> _validate_data_input(data=data, required_z=True, kind="matrix") + >>> _validate_data_input(data=data, required_cols=3, kind="matrix") Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given. >>> _validate_data_input( ... data=pd.DataFrame(data, columns=["x", "y"]), - ... required_z=True, + ... required_cols=3, ... kind="matrix", ... ) Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given. >>> _validate_data_input( ... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])), - ... required_z=True, - ... kind="matrix", + ... required_cols=3, ... ) Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given. >>> _validate_data_input(data="infile", x=[1, 2, 3]) Traceback (most recent call last): ... @@ -94,26 +93,38 @@ def _validate_data_input( GMTInvalidInput If the data input is not valid. """ - if data is None: # data is None - if x is None and y is None: # both x and y are None - if required_data: # data is not optional + if kind is None: + kind = data_kind(data, required=required_data) + + if data is not None and any(v is not None for v in (x, y, z)): + raise GMTInvalidInput("Too much data. Use either data or x/y/z.") + + match kind: + case "none": + if x is None and y is None: # both x and y are None raise GMTInvalidInput("No input data provided.") - elif x is None or y is None: # either x or y is None - raise GMTInvalidInput("Must provide both x and y.") - if required_z and z is None: # both x and y are not None, now check z - raise GMTInvalidInput("Must provide x, y, and z.") - else: # data is not None - if x is not None or y is not None or z is not None: - raise GMTInvalidInput("Too much data. Use either data or x/y/z.") - # For 'matrix' kind, check if data has the required z column - if kind == "matrix" and required_z: - if hasattr(data, "shape"): # np.ndarray or pd.DataFrame - if len(data.shape) == 1 and data.shape[0] < 3: - raise GMTInvalidInput("data must provide x, y, and z columns.") - if len(data.shape) > 1 and data.shape[1] < 3: - raise GMTInvalidInput("data must provide x, y, and z columns.") - if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset - raise GMTInvalidInput("data must provide x, y, and z columns.") + if x is None or y is None: # either x or y is None + raise GMTInvalidInput("Must provide both x and y.") + if required_cols >= 3 and z is None: + # both x and y are not None, now check z + raise GMTInvalidInput("Must provide x, y, and z.") + case "matrix": # 2-D numpy.ndarray + if (actual_cols := data.shape[1]) < required_cols: + msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given." + raise GMTInvalidInput(msg) + case "vectors": + if hasattr(data, "items") and not hasattr(data, "to_frame"): + # Dict, pd.DataFrame, xr.Dataset + arrays = [array for _, array in data.items()] + if (actual_cols := len(arrays)) < required_cols: + msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given." + raise GMTInvalidInput(msg) + + # Loop over columns to make sure they're not None + for idx, array in enumerate(arrays[:required_cols]): + if array is None: + msg = f"data needs {required_cols} columns but the {idx} column is None." + raise GMTInvalidInput(msg) def _check_encoding( diff --git a/pygmt/src/blockm.py b/pygmt/src/blockm.py index a8b35d6c942..bf4fc1299c3 100644 --- a/pygmt/src/blockm.py +++ b/pygmt/src/blockm.py @@ -55,7 +55,7 @@ def _blockm( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3 ) as vintbl, lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, ): diff --git a/pygmt/src/contour.py b/pygmt/src/contour.py index c5aa26a3b10..f0af0549fe3 100644 --- a/pygmt/src/contour.py +++ b/pygmt/src/contour.py @@ -145,7 +145,7 @@ def contour(self, data=None, x=None, y=None, z=None, **kwargs): with Session() as lib: with lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3 ) as vintbl: lib.call_module( module="contour", args=build_arg_list(kwargs, infile=vintbl) diff --git a/pygmt/src/nearneighbor.py b/pygmt/src/nearneighbor.py index 7027e04a358..ebafceb8f67 100644 --- a/pygmt/src/nearneighbor.py +++ b/pygmt/src/nearneighbor.py @@ -140,7 +140,7 @@ def nearneighbor( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): diff --git a/pygmt/src/plot.py b/pygmt/src/plot.py index 76d0314c609..b8eb64ff8ac 100644 --- a/pygmt/src/plot.py +++ b/pygmt/src/plot.py @@ -211,6 +211,7 @@ def plot( # noqa: PLR0912 kind = data_kind(data) if kind == "none": # Vectors input data = {"x": x, "y": y} + x, y = None, None # Parameters for vector styles if ( kwargs.get("S") is not None @@ -255,5 +256,5 @@ def plot( # noqa: PLR0912 pass with Session() as lib: - with lib.virtualfile_in(check_kind="vector", data=data) as vintbl: + with lib.virtualfile_in(check_kind="vector", data=data, x=x, y=y) as vintbl: lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index c0285a4c7cb..5b59db48172 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -186,6 +186,7 @@ def plot3d( # noqa: PLR0912 kind = data_kind(data) if kind == "none": # Vectors input data = {"x": x, "y": y, "z": z} + x, y, z = None, None, None # Parameters for vector styles if ( kwargs.get("S") is not None @@ -231,6 +232,6 @@ def plot3d( # noqa: PLR0912 with Session() as lib: with lib.virtualfile_in( - check_kind="vector", data=data, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3 ) as vintbl: lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/project.py b/pygmt/src/project.py index 811a7d48158..2de3a18e8f0 100644 --- a/pygmt/src/project.py +++ b/pygmt/src/project.py @@ -246,7 +246,7 @@ def project( x=x, y=y, z=z, - required_z=False, + required_cols=2, required_data=False, ) as vintbl, lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, diff --git a/pygmt/src/surface.py b/pygmt/src/surface.py index 23fdbdb353d..bdae1b1b14e 100644 --- a/pygmt/src/surface.py +++ b/pygmt/src/surface.py @@ -153,7 +153,7 @@ def surface(data=None, x=None, y=None, z=None, outgrid: str | None = None, **kwa with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): diff --git a/pygmt/src/triangulate.py b/pygmt/src/triangulate.py index 1765bd1d28e..6adfe08ddba 100644 --- a/pygmt/src/triangulate.py +++ b/pygmt/src/triangulate.py @@ -138,7 +138,7 @@ def regular_grid( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=False + check_kind="vector", data=data, x=x, y=y, z=z, required_cols=2 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): @@ -238,7 +238,7 @@ def delaunay_triples( with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=False + check_kind="vector", data=data, x=x, y=y, z=z, required_cols=2 ) as vintbl, lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, ): diff --git a/pygmt/src/wiggle.py b/pygmt/src/wiggle.py index 921c5317349..105eada44b5 100644 --- a/pygmt/src/wiggle.py +++ b/pygmt/src/wiggle.py @@ -108,6 +108,6 @@ def wiggle( with Session() as lib: with lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3 ) as vintbl: lib.call_module(module="wiggle", args=build_arg_list(kwargs, infile=vintbl)) diff --git a/pygmt/src/xyz2grd.py b/pygmt/src/xyz2grd.py index 2eedfb62e83..eeaf308729b 100644 --- a/pygmt/src/xyz2grd.py +++ b/pygmt/src/xyz2grd.py @@ -145,7 +145,7 @@ def xyz2grd(data=None, x=None, y=None, z=None, outgrid: str | None = None, **kwa with Session() as lib: with ( lib.virtualfile_in( - check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3 ) as vintbl, lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd, ): diff --git a/pygmt/tests/test_clib_virtualfiles.py b/pygmt/tests/test_clib_virtualfiles.py index b8b5ee0500d..2a6cbdcab6a 100644 --- a/pygmt/tests/test_clib_virtualfiles.py +++ b/pygmt/tests/test_clib_virtualfiles.py @@ -141,7 +141,7 @@ def test_virtualfile_in_required_z_matrix(array_func, kind): data = array_func(dataframe) with clib.Session() as lib: with lib.virtualfile_in( - data=data, required_z=True, check_kind="vector" + data=data, required_cols=3, check_kind="vector" ) as vfile: with GMTTempFile() as outfile: lib.call_module("info", [vfile, f"->{outfile.name}"]) @@ -163,7 +163,7 @@ def test_virtualfile_in_required_z_matrix_missing(): data = np.ones((5, 2)) with clib.Session() as lib: with pytest.raises(GMTInvalidInput): - with lib.virtualfile_in(data=data, required_z=True, check_kind="vector"): + with lib.virtualfile_in(data=data, required_cols=3, check_kind="vector"): pass @@ -190,7 +190,7 @@ def test_virtualfile_in_fail_non_valid_data(data): with clib.Session() as lib: with pytest.raises(GMTInvalidInput): lib.virtualfile_in( - x=variable[0], y=variable[1], z=variable[2], required_z=True + x=variable[0], y=variable[1], z=variable[2], required_cols=3 ) # Should also fail if given too much data