Skip to content

Commit

Permalink
pygmt.grdhisteq: Refactor to make it easier to maintain
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Feb 28, 2024
1 parent 5014591 commit 5cf2375
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 134 deletions.
185 changes: 59 additions & 126 deletions pygmt/src/grdhisteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class grdhisteq: # noqa: N801
@fmt_docstring
@use_alias(
C="divisions",
D="outfile",
G="outgrid",
R="region",
N="gaussian",
Expand All @@ -63,89 +62,7 @@ class grdhisteq: # noqa: N801
h="header",
)
@kwargs_to_strings(R="sequence")
def _grdhisteq(grid, output_type, **kwargs):
r"""
Perform histogram equalization for a grid.
Must provide ``outfile`` or ``outgrid``.
Full option list at :gmt-docs:`grdhisteq.html`
{aliases}
Parameters
----------
{grid}
{outgrid}
outfile : str, bool, or None
The name of the output ASCII file to store the results of the
histogram equalization in.
output_type: str
Determine the output type. Use "file", "xarray", "pandas", or
"numpy".
divisions : int
Set the number of divisions of the data range [Default is ``16``].
{region}
{verbose}
{header}
Returns
-------
ret: pandas.DataFrame or xarray.DataArray or None
Return type depends on whether the ``outgrid`` parameter is set:
- xarray.DataArray if ``output_type`` is "xarray""
- numpy.ndarray if ``output_type`` is "numpy"
- pandas.DataFrame if ``output_type`` is "pandas"
- None if ``output_type`` is "file" (output is stored in
``outgrid`` or ``outfile``)
See Also
--------
:func:`pygmt.grd2cpt`
"""

with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with file_context as infile:
lib.call_module(
module="grdhisteq", args=build_arg_string(kwargs, infile=infile)
)

if output_type == "file":
return None
if output_type == "xarray":
return load_dataarray(kwargs["G"])

result = pd.read_csv(
filepath_or_buffer=kwargs["D"],
sep="\t",
header=None,
names=["start", "stop", "bin_id"],
dtype={
"start": np.float32,
"stop": np.float32,
"bin_id": np.uint32,
},
)
if output_type == "numpy":
return result.to_numpy()

return result.set_index("bin_id")

@staticmethod
@fmt_docstring
def equalize_grid(
grid,
*,
outgrid=None,
divisions=None,
region=None,
gaussian=None,
quadratic=None,
verbose=None,
):
def equalize_grid(grid, **kwargs):
r"""
Perform histogram equalization for a grid.
Expand All @@ -157,6 +74,8 @@ def equalize_grid(
Full option list at :gmt-docs:`grdhisteq.html`
{aliases}
Parameters
----------
{grid}
Expand Down Expand Up @@ -202,39 +121,31 @@ def equalize_grid(
This method does a weighted histogram equalization for geographic
grids to account for node area varying with latitude.
"""
# Return an xarray.DataArray if ``outgrid`` is not set
with GMTTempFile(suffix=".nc") as tmpfile:
if isinstance(outgrid, str):
output_type = "file"
elif outgrid is None:
output_type = "xarray"
outgrid = tmpfile.name
else:
raise GMTInvalidInput("Must specify 'outgrid' as a string or None.")
return grdhisteq._grdhisteq(
grid=grid,
output_type=output_type,
outgrid=outgrid,
divisions=divisions,
region=region,
gaussian=gaussian,
quadratic=quadratic,
verbose=verbose,
)
with Session() as lib:
with lib.virtualfile_from_data(
check_kind="raster", data=grid
) as vingrd:
if (outgrid := kwargs.get("G")) is None:
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile
lib.call_module(
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
)
return load_dataarray(outgrid) if outgrid == tmpfile.name else None

@staticmethod
@fmt_docstring
def compute_bins(
grid,
*,
output_type="pandas",
outfile=None,
divisions=None,
quadratic=None,
verbose=None,
region=None,
header=None,
):
@use_alias(
C="divisions",
D="outfile",
R="region",
N="gaussian",
Q="quadratic",
V="verbose",
h="header",
)
@kwargs_to_strings(R="sequence")
def compute_bins(grid, output_type="pandas", **kwargs):
r"""
Perform histogram equalization for a grid.
Expand All @@ -254,6 +165,8 @@ def compute_bins(
Full option list at :gmt-docs:`grdhisteq.html`
{aliases}
Parameters
----------
{grid}
Expand Down Expand Up @@ -314,21 +227,41 @@ def compute_bins(
This method does a weighted histogram equalization for geographic
grids to account for node area varying with latitude.
"""
outfile = kwargs.get("D")
output_type = validate_output_table_type(output_type, outfile=outfile)

if header is not None and output_type != "file":
if kwargs.get("h") is not None and output_type != "file":
raise GMTInvalidInput("'header' is only allowed with output_type='file'.")

with GMTTempFile(suffix=".txt") as tmpfile:
if output_type != "file":
outfile = tmpfile.name
return grdhisteq._grdhisteq(
grid,
output_type=output_type,
outfile=outfile,
divisions=divisions,
quadratic=quadratic,
verbose=verbose,
region=region,
header=header,
)
with Session() as lib:
with lib.virtualfile_from_data(
check_kind="raster", data=grid
) as vingrd:
if outfile is None:
kwargs["D"] = outfile = tmpfile.name # output to tmpfile
lib.call_module(
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
)

if outfile == tmpfile.name:
# if user did not set outfile, return pd.DataFrame
result = pd.read_csv(
filepath_or_buffer=outfile,
sep="\t",
header=None,
names=["start", "stop", "bin_id"],
dtype={
"start": np.float32,
"stop": np.float32,
"bin_id": np.uint32,
},
)
elif outfile != tmpfile.name:
# return None if outfile set, output in outfile
return None

if output_type == "numpy":
return result.to_numpy()

return result.set_index("bin_id")
8 changes: 0 additions & 8 deletions pygmt/tests/test_grdhisteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,3 @@ def test_compute_bins_invalid_format(grid):
grdhisteq.compute_bins(grid=grid, output_type=1)
with pytest.raises(GMTInvalidInput):
grdhisteq.compute_bins(grid=grid, output_type="pandas", header="o+c")


def test_equalize_grid_invalid_format(grid):
"""
Test that equalize_grid fails with incorrect format.
"""
with pytest.raises(GMTInvalidInput):
grdhisteq.equalize_grid(grid=grid, outgrid=True)

0 comments on commit 5cf2375

Please sign in to comment.