diff --git a/pygmt/src/grdvolume.py b/pygmt/src/grdvolume.py index 1bb696e9e04..eaa70eacf61 100644 --- a/pygmt/src/grdvolume.py +++ b/pygmt/src/grdvolume.py @@ -2,10 +2,12 @@ grdvolume - Calculate grid volume and area constrained by a contour. """ +from typing import Literal + import pandas as pd +import xarray as xr from pygmt.clib import Session from pygmt.helpers import ( - GMTTempFile, build_arg_string, fmt_docstring, kwargs_to_strings, @@ -24,7 +26,12 @@ V="verbose", ) @kwargs_to_strings(C="sequence", R="sequence") -def grdvolume(grid, output_type="pandas", outfile=None, **kwargs): +def grdvolume( + grid, + output_type: Literal["pandas", "numpy", "file"] = "pandas", + outfile: str | None = None, + **kwargs, +) -> pd.DataFrame | xr.DataArray | None: r""" Determine the volume between the surface of a grid and a plane. @@ -41,15 +48,8 @@ def grdvolume(grid, output_type="pandas", outfile=None, **kwargs): Parameters ---------- {grid} - output_type : str - Determine the format the output data will be returned in [Default is - ``pandas``]: - - - ``numpy`` - :class:`numpy.ndarray` - - ``pandas``- :class:`pandas.DataFrame` - - ``file`` - ASCII file (requires ``outfile``) - outfile : str - The file name for the output ASCII file. + {output_type} + {outfile} contour : str, float, or list *cval*\|\ *low/high/delta*\|\ **r**\ *low/high*\|\ **r**\ *cval*. Find area, volume and mean height (volume/area) inside and above the @@ -69,14 +69,12 @@ def grdvolume(grid, output_type="pandas", outfile=None, **kwargs): Returns ------- - ret : pandas.DataFrame or numpy.ndarray or None + ret Return type depends on ``outfile`` and ``output_type``: - - None if ``outfile`` is set (output will be stored in file set by - ``outfile``) - - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` - is not set (depends on ``output_type`` [Default is - :class:`pandas.DataFrame`]) + - None if ``outfile`` is set (output will be stored in file set by ``outfile``) + - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set + (depends on ``output_type``) Example ------- @@ -103,22 +101,13 @@ def grdvolume(grid, output_type="pandas", outfile=None, **kwargs): """ output_type = validate_output_table_type(output_type, outfile=outfile) - with GMTTempFile() as tmpfile: - with Session() as lib: - with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd: - if outfile is None: - outfile = tmpfile.name - lib.call_module( - module="grdvolume", - args=build_arg_string(kwargs, infile=vingrd, outfile=outfile), - ) - - # Read temporary csv output to a pandas table - if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame - result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") - elif outfile != tmpfile.name: # return None if outfile set, output in outfile - result = None - - if output_type == "numpy": - result = result.to_numpy() - return result + with Session() as lib: + with ( + lib.virtualfile_in(check_kind="raster", data=grid) as vingrd, + lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, + ): + lib.call_module( + module="grdvolume", + args=build_arg_string(kwargs, infile=vingrd, outfile=vouttbl), + ) + return lib.virtualfile_to_dataset(output_type=output_type, vfname=vouttbl) diff --git a/pygmt/tests/test_grdvolume.py b/pygmt/tests/test_grdvolume.py index b56e76b1f19..5ac803490d5 100644 --- a/pygmt/tests/test_grdvolume.py +++ b/pygmt/tests/test_grdvolume.py @@ -2,15 +2,11 @@ Test pygmt.grdvolume. """ -from pathlib import Path - import numpy as np import numpy.testing as npt import pandas as pd import pytest from pygmt import grdvolume -from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -22,14 +18,6 @@ def fixture_grid(): return load_static_earth_relief() -@pytest.fixture(scope="module", name="region") -def fixture_region(): - """ - Set the data region for the tests. - """ - return [-53, -50, -22, -20] - - @pytest.fixture(scope="module", name="data") def fixture_data(): """ @@ -47,57 +35,16 @@ def fixture_data(): return data -def test_grdvolume_format(grid, region): - """ - Test that correct formats are returned. - """ - grdvolume_default = grdvolume(grid=grid, region=region) - assert isinstance(grdvolume_default, pd.DataFrame) - grdvolume_array = grdvolume(grid=grid, output_type="numpy", region=region) - assert isinstance(grdvolume_array, np.ndarray) - grdvolume_df = grdvolume(grid=grid, output_type="pandas", region=region) - assert isinstance(grdvolume_df, pd.DataFrame) - - -def test_grdvolume_invalid_format(grid): - """ - Test that grdvolume fails with incorrect output_type argument. - """ - with pytest.raises(GMTInvalidInput): - grdvolume(grid=grid, output_type=1) - - -def test_grdvolume_no_outfile(grid): - """ - Test that grdvolume fails when output_type set to 'file' but no outfile is - specified. - """ - with pytest.raises(GMTInvalidInput): - grdvolume(grid=grid, output_type="file") - - @pytest.mark.benchmark -def test_grdvolume_no_outgrid(grid, data, region): +def test_grdvolume(grid, data): """ - Test the expected output of grdvolume with no output file set. + Test the basic functionality of grdvolume. """ test_output = grdvolume( - grid=grid, contour=[200, 400, 50], output_type="numpy", region=region + grid=grid, + contour=[200, 400, 50], + output_type="numpy", + region=[-53, -50, -22, -20], ) + assert isinstance(test_output, pd.DataFrame) npt.assert_allclose(test_output, data) - - -def test_grdvolume_outgrid(grid, region): - """ - Test the expected output of grdvolume with an output file set. - """ - with GMTTempFile(suffix=".csv") as tmpfile: - result = grdvolume( - grid=grid, - contour=[200, 400, 50], - output_type="file", - outfile=tmpfile.name, - region=region, - ) - assert result is None # return value is None - assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists