Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pygmt.grdtrack: Add 'output_type' parameter for output in pandas/numpy/file formats #3106

Merged
merged 3 commits into from
Mar 15, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 41 additions & 38 deletions pygmt/src/grdtrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
grdtrack - Sample grids at specified (x,y) locations.
"""

from typing import Literal

import numpy as np
import pandas as pd
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
fmt_docstring,
kwargs_to_strings,
use_alias,
validate_output_table_type,
)

__doctest_skip__ = ["grdtrack"]
Expand Down Expand Up @@ -44,7 +47,14 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
def grdtrack(
grid,
points=None,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
outfile: str | None = None,
newcolname=None,
**kwargs,
) -> pd.DataFrame | np.ndarray | None:
r"""
Sample grids at specified (x,y) locations.

Expand Down Expand Up @@ -73,15 +83,12 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
points : str, {table-like}
Pass in either a file name to an ASCII data table, a 2-D
{table-classes}.

{output_type}
{outfile}
newcolname : str
Required if ``points`` is a :class:`pandas.DataFrame`. The name for the
new column in the track :class:`pandas.DataFrame` table where the
sampled values will be placed.

outfile : str
The file name for the output ASCII file.

resample : str
**f**\|\ **p**\|\ **m**\|\ **r**\|\ **R**\ [**+l**]
For track resampling (if ``crossprofile`` or ``profile`` are set) we
Expand Down Expand Up @@ -258,13 +265,12 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):

Returns
-------
track: pandas.DataFrame or None
Return type depends on whether the ``outfile`` parameter is set:
ret
Return type depends on ``outfile`` and ``output_type``:

- :class:`pandas.DataFrame` table with (x, y, ..., newcolname) if
``outfile`` is not set
- None if ``outfile`` is set (track output will be stored in file set
by ``outfile``)
- None if ``outfile`` is set (output will be stored in file set by ``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
(depends on ``output_type``)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- None if ``outfile`` is set (output will be stored in file set by ``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
(depends on ``output_type``)
- ``None`` if ``outfile`` is set (output will be stored in file set by ``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
(depends on ``output_type``)


Example
-------
Expand All @@ -291,30 +297,27 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
if hasattr(points, "columns") and newcolname is None:
raise GMTInvalidInput("Please pass in a str to 'newcolname'")

with GMTTempFile(suffix=".csv") as tmpfile:
with Session() as lib:
with (
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
lib.virtualfile_in(
check_kind="vector", data=points, required_data=False
) as vintbl,
):
kwargs["G"] = vingrd
if outfile is None: # Output to tmpfile if outfile is not set
outfile = tmpfile.name
lib.call_module(
module="grdtrack",
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
)
output_type = validate_output_table_type(output_type, outfile=outfile)

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
try:
column_names = [*points.columns.to_list(), newcolname]
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
except AttributeError: # 'str' object has no attribute 'columns'
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
result = None
column_names = None
if output_type == "pandas" and isinstance(points, pd.DataFrame):
column_names = [*points.columns.to_list(), newcolname]

return result
with Session() as lib:
with (
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
lib.virtualfile_in(
check_kind="vector", data=points, required_data=False
) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
):
kwargs["G"] = vingrd
lib.call_module(
module="grdtrack",
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
)
return lib.virtualfile_to_dataset(
output_type=output_type,
vfname=vouttbl,
column_names=column_names,
)
Loading