diff --git a/pygmt/src/grdtrack.py b/pygmt/src/grdtrack.py index 1e5df5ffbda..d71300fa3ba 100644 --- a/pygmt/src/grdtrack.py +++ b/pygmt/src/grdtrack.py @@ -2,15 +2,18 @@ grdtrack - Sample grids at specified (x,y) locations. """ +from typing import Literal + +import numpy as np import pandas as pd from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( - GMTTempFile, build_arg_string, fmt_docstring, kwargs_to_strings, use_alias, + validate_output_table_type, ) __doctest_skip__ = ["grdtrack"] @@ -44,7 +47,14 @@ w="wrap", ) @kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma") -def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs): +def grdtrack( + grid, + points=None, + output_type: Literal["pandas", "numpy", "file"] = "pandas", + outfile: str | None = None, + newcolname=None, + **kwargs, +) -> pd.DataFrame | np.ndarray | None: r""" Sample grids at specified (x,y) locations. @@ -73,15 +83,12 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs): points : str, {table-like} Pass in either a file name to an ASCII data table, a 2-D {table-classes}. - + {output_type} + {outfile} newcolname : str Required if ``points`` is a :class:`pandas.DataFrame`. The name for the new column in the track :class:`pandas.DataFrame` table where the sampled values will be placed. - - outfile : str - The file name for the output ASCII file. - resample : str **f**\|\ **p**\|\ **m**\|\ **r**\|\ **R**\ [**+l**] For track resampling (if ``crossprofile`` or ``profile`` are set) we @@ -258,13 +265,13 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs): Returns ------- - track: pandas.DataFrame or None - Return type depends on whether the ``outfile`` parameter is set: + ret + Return type depends on ``outfile`` and ``output_type``: - - :class:`pandas.DataFrame` table with (x, y, ..., newcolname) if - ``outfile`` is not set - - None if ``outfile`` is set (track output will be stored in file set - by ``outfile``) + - ``None`` if ``outfile`` is set (output will be stored in file set by + ``outfile``) + - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set + (depends on ``output_type``) Example ------- @@ -291,30 +298,27 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs): if hasattr(points, "columns") and newcolname is None: raise GMTInvalidInput("Please pass in a str to 'newcolname'") - with GMTTempFile(suffix=".csv") as tmpfile: - with Session() as lib: - with ( - lib.virtualfile_in(check_kind="raster", data=grid) as vingrd, - lib.virtualfile_in( - check_kind="vector", data=points, required_data=False - ) as vintbl, - ): - kwargs["G"] = vingrd - if outfile is None: # Output to tmpfile if outfile is not set - outfile = tmpfile.name - lib.call_module( - module="grdtrack", - args=build_arg_string(kwargs, infile=vintbl, outfile=outfile), - ) + output_type = validate_output_table_type(output_type, outfile=outfile) - # Read temporary csv output to a pandas table - if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame - try: - column_names = [*points.columns.to_list(), newcolname] - result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) - except AttributeError: # 'str' object has no attribute 'columns' - result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") - elif outfile != tmpfile.name: # return None if outfile set, output in outfile - result = None + column_names = None + if output_type == "pandas" and isinstance(points, pd.DataFrame): + column_names = [*points.columns.to_list(), newcolname] - return result + with Session() as lib: + with ( + lib.virtualfile_in(check_kind="raster", data=grid) as vingrd, + lib.virtualfile_in( + check_kind="vector", data=points, required_data=False + ) as vintbl, + lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, + ): + kwargs["G"] = vingrd + lib.call_module( + module="grdtrack", + args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl), + ) + return lib.virtualfile_to_dataset( + output_type=output_type, + vfname=vouttbl, + column_names=column_names, + )