diff --git a/pygmt/src/select.py b/pygmt/src/select.py index fe132b356f9..b44da01a9eb 100644 --- a/pygmt/src/select.py +++ b/pygmt/src/select.py @@ -2,14 +2,17 @@ select - Select data table subsets based on multiple spatial criteria. """ +from typing import Literal + +import numpy as np import pandas as pd from pygmt.clib import Session from pygmt.helpers import ( - GMTTempFile, build_arg_string, fmt_docstring, kwargs_to_strings, use_alias, + validate_output_table_type, ) __doctest_skip__ = ["select"] @@ -41,7 +44,12 @@ w="wrap", ) @kwargs_to_strings(M="sequence", R="sequence", i="sequence_comma", o="sequence_comma") -def select(data=None, outfile=None, **kwargs): +def select( + data=None, + output_type: Literal["pandas", "numpy", "file"] = "pandas", + outfile: str | None = None, + **kwargs, +) -> pd.DataFrame | np.ndarray | None: r""" Select data table subsets based on multiple spatial criteria. @@ -70,8 +78,8 @@ def select(data=None, outfile=None, **kwargs): data : str, {table-like} Pass in either a file name to an ASCII data table, a 2-D {table-classes}. - outfile : str - The file name for the output ASCII file. + {output_type} + {outfile} {area_thresh} dist2pt : str *pointfile*\|\ *lon*/*lat*\ **+d**\ *dist*. @@ -180,12 +188,12 @@ def select(data=None, outfile=None, **kwargs): Returns ------- - output : pandas.DataFrame or None - Return type depends on whether the ``outfile`` parameter is set: + ret + Return type depends on ``outfile`` and ``output_type``: - - :class:`pandas.DataFrame` table if ``outfile`` is not set. - - None if ``outfile`` is set (filtered output will be stored in file - set by ``outfile``). + - None if ``outfile`` is set (output will be stored in file set by ``outfile``) + - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set + (depends on ``output_type``) Example ------- @@ -196,25 +204,23 @@ def select(data=None, outfile=None, **kwargs): >>> # longitudes 246 and 247 and latitudes 20 and 21 >>> out = pygmt.select(data=ship_data, region=[246, 247, 20, 21]) """ - - with GMTTempFile(suffix=".csv") as tmpfile: - with Session() as lib: - with lib.virtualfile_in(check_kind="vector", data=data) as vintbl: - if outfile is None: - outfile = tmpfile.name - lib.call_module( - module="select", - args=build_arg_string(kwargs, infile=vintbl, outfile=outfile), - ) - - # Read temporary csv output to a pandas table - if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame - try: - column_names = data.columns.to_list() - result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) - except AttributeError: # 'str' object has no attribute 'columns' - result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") - elif outfile != tmpfile.name: # return None if outfile set, output in outfile - result = None - - return result + output_type = validate_output_table_type(output_type, outfile=outfile) + + column_names = None + if output_type == "pandas" and isinstance(data, pd.DataFrame): + column_names = data.columns.to_list() + + with Session() as lib: + with ( + lib.virtualfile_in(check_kind="vector", data=data) as vintbl, + lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, + ): + lib.call_module( + module="select", + args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl), + ) + return lib.virtualfile_to_dataset( + output_type=output_type, + vfile=vouttbl, + column_names=column_names, + )