Skip to content

Commit

Permalink
pygmt.select: Improve performance by storing output in virtual files
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Mar 13, 2024
1 parent 3a507a8 commit c79b749
Showing 1 changed file with 37 additions and 31 deletions.
68 changes: 37 additions & 31 deletions pygmt/src/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
select - Select data table subsets based on multiple spatial criteria.
"""

from typing import Literal

import numpy as np
import pandas as pd
from pygmt.clib import Session
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
fmt_docstring,
kwargs_to_strings,
use_alias,
validate_output_table_type,
)

__doctest_skip__ = ["select"]
Expand Down Expand Up @@ -41,7 +44,12 @@
w="wrap",
)
@kwargs_to_strings(M="sequence", R="sequence", i="sequence_comma", o="sequence_comma")
def select(data=None, outfile=None, **kwargs):
def select(
data=None,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
outfile: str | None = None,
**kwargs,
) -> pd.DataFrame | np.ndarray | None:
r"""
Select data table subsets based on multiple spatial criteria.
Expand Down Expand Up @@ -70,8 +78,8 @@ def select(data=None, outfile=None, **kwargs):
data : str, {table-like}
Pass in either a file name to an ASCII data table, a 2-D
{table-classes}.
outfile : str
The file name for the output ASCII file.
{output_type}
{outfile}
{area_thresh}
dist2pt : str
*pointfile*\|\ *lon*/*lat*\ **+d**\ *dist*.
Expand Down Expand Up @@ -180,12 +188,12 @@ def select(data=None, outfile=None, **kwargs):
Returns
-------
output : pandas.DataFrame or None
Return type depends on whether the ``outfile`` parameter is set:
ret
Return type depends on ``outfile`` and ``output_type``:
- :class:`pandas.DataFrame` table if ``outfile`` is not set.
- None if ``outfile`` is set (filtered output will be stored in file
set by ``outfile``).
- None if ``outfile`` is set (output will be stored in file set by ``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
(depends on ``output_type``)
Example
-------
Expand All @@ -196,25 +204,23 @@ def select(data=None, outfile=None, **kwargs):
>>> # longitudes 246 and 247 and latitudes 20 and 21
>>> out = pygmt.select(data=ship_data, region=[246, 247, 20, 21])
"""

with GMTTempFile(suffix=".csv") as tmpfile:
with Session() as lib:
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
if outfile is None:
outfile = tmpfile.name
lib.call_module(
module="select",
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
)

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
try:
column_names = data.columns.to_list()
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
except AttributeError: # 'str' object has no attribute 'columns'
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
result = None

return result
output_type = validate_output_table_type(output_type, outfile=outfile)

column_names = None
if output_type == "pandas" and isinstance(data, pd.DataFrame):
column_names = data.columns.to_list()

with Session() as lib:
with (
lib.virtualfile_in(check_kind="vector", data=data) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
):
lib.call_module(
module="select",
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
)
return lib.virtualfile_to_dataset(
output_type=output_type,
vfname=vouttbl,
column_names=column_names,
)

0 comments on commit c79b749

Please sign in to comment.