Skip to content

pygmt.grdhisteq: Refactor to use codes consistent with other wrappers #3076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 59 additions & 126 deletions pygmt/src/grdhisteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class grdhisteq: # noqa: N801
@fmt_docstring
@use_alias(
C="divisions",
D="outfile",
G="outgrid",
R="region",
N="gaussian",
Expand All @@ -63,89 +62,7 @@ class grdhisteq: # noqa: N801
h="header",
)
@kwargs_to_strings(R="sequence")
def _grdhisteq(grid, output_type, **kwargs):
r"""
Perform histogram equalization for a grid.

Must provide ``outfile`` or ``outgrid``.

Full option list at :gmt-docs:`grdhisteq.html`

{aliases}

Parameters
----------
{grid}
{outgrid}
outfile : str, bool, or None
The name of the output ASCII file to store the results of the
histogram equalization in.
output_type: str
Determine the output type. Use "file", "xarray", "pandas", or
"numpy".
divisions : int
Set the number of divisions of the data range [Default is ``16``].

{region}
{verbose}
{header}

Returns
-------
ret: pandas.DataFrame or xarray.DataArray or None
Return type depends on whether the ``outgrid`` parameter is set:

- xarray.DataArray if ``output_type`` is "xarray""
- numpy.ndarray if ``output_type`` is "numpy"
- pandas.DataFrame if ``output_type`` is "pandas"
- None if ``output_type`` is "file" (output is stored in
``outgrid`` or ``outfile``)

See Also
--------
:func:`pygmt.grd2cpt`
"""

with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with file_context as infile:
lib.call_module(
module="grdhisteq", args=build_arg_string(kwargs, infile=infile)
)

if output_type == "file":
return None
if output_type == "xarray":
return load_dataarray(kwargs["G"])

result = pd.read_csv(
filepath_or_buffer=kwargs["D"],
sep="\t",
header=None,
names=["start", "stop", "bin_id"],
dtype={
"start": np.float32,
"stop": np.float32,
"bin_id": np.uint32,
},
)
if output_type == "numpy":
return result.to_numpy()

return result.set_index("bin_id")

@staticmethod
@fmt_docstring
def equalize_grid(
grid,
*,
outgrid=None,
divisions=None,
region=None,
gaussian=None,
quadratic=None,
verbose=None,
):
def equalize_grid(grid, **kwargs):
r"""
Perform histogram equalization for a grid.

Expand All @@ -157,6 +74,8 @@ def equalize_grid(

Full option list at :gmt-docs:`grdhisteq.html`

{aliases}

Parameters
----------
{grid}
Expand Down Expand Up @@ -202,39 +121,31 @@ def equalize_grid(
This method does a weighted histogram equalization for geographic
grids to account for node area varying with latitude.
"""
# Return an xarray.DataArray if ``outgrid`` is not set
with GMTTempFile(suffix=".nc") as tmpfile:
if isinstance(outgrid, str):
output_type = "file"
elif outgrid is None:
output_type = "xarray"
outgrid = tmpfile.name
else:
raise GMTInvalidInput("Must specify 'outgrid' as a string or None.")
return grdhisteq._grdhisteq(
grid=grid,
output_type=output_type,
outgrid=outgrid,
divisions=divisions,
region=region,
gaussian=gaussian,
quadratic=quadratic,
verbose=verbose,
)
with Session() as lib:
with lib.virtualfile_from_data(
check_kind="raster", data=grid
) as vingrd:
if (outgrid := kwargs.get("G")) is None:
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile
lib.call_module(
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
)
return load_dataarray(outgrid) if outgrid == tmpfile.name else None

@staticmethod
@fmt_docstring
def compute_bins(
grid,
*,
output_type="pandas",
outfile=None,
divisions=None,
quadratic=None,
verbose=None,
region=None,
header=None,
):
@use_alias(
C="divisions",
D="outfile",
R="region",
N="gaussian",
Q="quadratic",
V="verbose",
h="header",
)
@kwargs_to_strings(R="sequence")
def compute_bins(grid, output_type="pandas", **kwargs):
r"""
Perform histogram equalization for a grid.

Expand All @@ -254,6 +165,8 @@ def compute_bins(

Full option list at :gmt-docs:`grdhisteq.html`

{aliases}

Parameters
----------
{grid}
Expand Down Expand Up @@ -314,21 +227,41 @@ def compute_bins(
This method does a weighted histogram equalization for geographic
grids to account for node area varying with latitude.
"""
outfile = kwargs.get("D")
output_type = validate_output_table_type(output_type, outfile=outfile)

if header is not None and output_type != "file":
if kwargs.get("h") is not None and output_type != "file":
raise GMTInvalidInput("'header' is only allowed with output_type='file'.")

with GMTTempFile(suffix=".txt") as tmpfile:
if output_type != "file":
outfile = tmpfile.name
return grdhisteq._grdhisteq(
grid,
output_type=output_type,
outfile=outfile,
divisions=divisions,
quadratic=quadratic,
verbose=verbose,
region=region,
header=header,
)
with Session() as lib:
with lib.virtualfile_from_data(
check_kind="raster", data=grid
) as vingrd:
if outfile is None:
kwargs["D"] = outfile = tmpfile.name # output to tmpfile
lib.call_module(
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
)

if outfile == tmpfile.name:
# if user did not set outfile, return pd.DataFrame
result = pd.read_csv(
filepath_or_buffer=outfile,
sep="\t",
header=None,
names=["start", "stop", "bin_id"],
dtype={
"start": np.float32,
"stop": np.float32,
"bin_id": np.uint32,
},
)
elif outfile != tmpfile.name:
# return None if outfile set, output in outfile
return None

if output_type == "numpy":
return result.to_numpy()

return result.set_index("bin_id")
8 changes: 0 additions & 8 deletions pygmt/tests/test_grdhisteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,3 @@ def test_compute_bins_invalid_format(grid):
grdhisteq.compute_bins(grid=grid, output_type=1)
with pytest.raises(GMTInvalidInput):
grdhisteq.compute_bins(grid=grid, output_type="pandas", header="o+c")


def test_equalize_grid_invalid_format(grid):
"""
Test that equalize_grid fails with incorrect format.
"""
with pytest.raises(GMTInvalidInput):
grdhisteq.equalize_grid(grid=grid, outgrid=True)