Skip to content

Commit

Permalink
pygmt.grdtrack: Add 'output_type' parameter for output in pandas/nump…
Browse files Browse the repository at this point in the history
…y/file formats (#3106)
  • Loading branch information
seisman committed Mar 15, 2024
1 parent 83b1a12 commit 8220180
Showing 1 changed file with 42 additions and 38 deletions.
80 changes: 42 additions & 38 deletions pygmt/src/grdtrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
grdtrack - Sample grids at specified (x,y) locations.
"""

from typing import Literal

import numpy as np
import pandas as pd
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
fmt_docstring,
kwargs_to_strings,
use_alias,
validate_output_table_type,
)

__doctest_skip__ = ["grdtrack"]
Expand Down Expand Up @@ -44,7 +47,14 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
def grdtrack(
grid,
points=None,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
outfile: str | None = None,
newcolname=None,
**kwargs,
) -> pd.DataFrame | np.ndarray | None:
r"""
Sample grids at specified (x,y) locations.
Expand Down Expand Up @@ -73,15 +83,12 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
points : str, {table-like}
Pass in either a file name to an ASCII data table, a 2-D
{table-classes}.
{output_type}
{outfile}
newcolname : str
Required if ``points`` is a :class:`pandas.DataFrame`. The name for the
new column in the track :class:`pandas.DataFrame` table where the
sampled values will be placed.
outfile : str
The file name for the output ASCII file.
resample : str
**f**\|\ **p**\|\ **m**\|\ **r**\|\ **R**\ [**+l**]
For track resampling (if ``crossprofile`` or ``profile`` are set) we
Expand Down Expand Up @@ -258,13 +265,13 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
Returns
-------
track: pandas.DataFrame or None
Return type depends on whether the ``outfile`` parameter is set:
ret
Return type depends on ``outfile`` and ``output_type``:
- :class:`pandas.DataFrame` table with (x, y, ..., newcolname) if
``outfile`` is not set
- None if ``outfile`` is set (track output will be stored in file set
by ``outfile``)
- ``None`` if ``outfile`` is set (output will be stored in file set by
``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
(depends on ``output_type``)
Example
-------
Expand All @@ -291,30 +298,27 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
if hasattr(points, "columns") and newcolname is None:
raise GMTInvalidInput("Please pass in a str to 'newcolname'")

with GMTTempFile(suffix=".csv") as tmpfile:
with Session() as lib:
with (
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
lib.virtualfile_in(
check_kind="vector", data=points, required_data=False
) as vintbl,
):
kwargs["G"] = vingrd
if outfile is None: # Output to tmpfile if outfile is not set
outfile = tmpfile.name
lib.call_module(
module="grdtrack",
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
)
output_type = validate_output_table_type(output_type, outfile=outfile)

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
try:
column_names = [*points.columns.to_list(), newcolname]
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
except AttributeError: # 'str' object has no attribute 'columns'
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
result = None
column_names = None
if output_type == "pandas" and isinstance(points, pd.DataFrame):
column_names = [*points.columns.to_list(), newcolname]

return result
with Session() as lib:
with (
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
lib.virtualfile_in(
check_kind="vector", data=points, required_data=False
) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
):
kwargs["G"] = vingrd
lib.call_module(
module="grdtrack",
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
)
return lib.virtualfile_to_dataset(
output_type=output_type,
vfname=vouttbl,
column_names=column_names,
)

0 comments on commit 8220180

Please sign in to comment.