diff --git a/pygmt/src/grdhisteq.py b/pygmt/src/grdhisteq.py index c9524aa1ab4..df66a9cf78e 100644 --- a/pygmt/src/grdhisteq.py +++ b/pygmt/src/grdhisteq.py @@ -54,7 +54,6 @@ class grdhisteq: # noqa: N801 @fmt_docstring @use_alias( C="divisions", - D="outfile", G="outgrid", R="region", N="gaussian", @@ -63,89 +62,7 @@ class grdhisteq: # noqa: N801 h="header", ) @kwargs_to_strings(R="sequence") - def _grdhisteq(grid, output_type, **kwargs): - r""" - Perform histogram equalization for a grid. - - Must provide ``outfile`` or ``outgrid``. - - Full option list at :gmt-docs:`grdhisteq.html` - - {aliases} - - Parameters - ---------- - {grid} - {outgrid} - outfile : str, bool, or None - The name of the output ASCII file to store the results of the - histogram equalization in. - output_type: str - Determine the output type. Use "file", "xarray", "pandas", or - "numpy". - divisions : int - Set the number of divisions of the data range [Default is ``16``]. - - {region} - {verbose} - {header} - - Returns - ------- - ret: pandas.DataFrame or xarray.DataArray or None - Return type depends on whether the ``outgrid`` parameter is set: - - - xarray.DataArray if ``output_type`` is "xarray"" - - numpy.ndarray if ``output_type`` is "numpy" - - pandas.DataFrame if ``output_type`` is "pandas" - - None if ``output_type`` is "file" (output is stored in - ``outgrid`` or ``outfile``) - - See Also - -------- - :func:`pygmt.grd2cpt` - """ - - with Session() as lib: - file_context = lib.virtualfile_from_data(check_kind="raster", data=grid) - with file_context as infile: - lib.call_module( - module="grdhisteq", args=build_arg_string(kwargs, infile=infile) - ) - - if output_type == "file": - return None - if output_type == "xarray": - return load_dataarray(kwargs["G"]) - - result = pd.read_csv( - filepath_or_buffer=kwargs["D"], - sep="\t", - header=None, - names=["start", "stop", "bin_id"], - dtype={ - "start": np.float32, - "stop": np.float32, - "bin_id": np.uint32, - }, - ) - if output_type == "numpy": - return result.to_numpy() - - return result.set_index("bin_id") - - @staticmethod - @fmt_docstring - def equalize_grid( - grid, - *, - outgrid=None, - divisions=None, - region=None, - gaussian=None, - quadratic=None, - verbose=None, - ): + def equalize_grid(grid, **kwargs): r""" Perform histogram equalization for a grid. @@ -157,6 +74,8 @@ def equalize_grid( Full option list at :gmt-docs:`grdhisteq.html` + {aliases} + Parameters ---------- {grid} @@ -202,39 +121,31 @@ def equalize_grid( This method does a weighted histogram equalization for geographic grids to account for node area varying with latitude. """ - # Return an xarray.DataArray if ``outgrid`` is not set with GMTTempFile(suffix=".nc") as tmpfile: - if isinstance(outgrid, str): - output_type = "file" - elif outgrid is None: - output_type = "xarray" - outgrid = tmpfile.name - else: - raise GMTInvalidInput("Must specify 'outgrid' as a string or None.") - return grdhisteq._grdhisteq( - grid=grid, - output_type=output_type, - outgrid=outgrid, - divisions=divisions, - region=region, - gaussian=gaussian, - quadratic=quadratic, - verbose=verbose, - ) + with Session() as lib: + with lib.virtualfile_from_data( + check_kind="raster", data=grid + ) as vingrd: + if (outgrid := kwargs.get("G")) is None: + kwargs["G"] = outgrid = tmpfile.name # output to tmpfile + lib.call_module( + module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd) + ) + return load_dataarray(outgrid) if outgrid == tmpfile.name else None @staticmethod @fmt_docstring - def compute_bins( - grid, - *, - output_type="pandas", - outfile=None, - divisions=None, - quadratic=None, - verbose=None, - region=None, - header=None, - ): + @use_alias( + C="divisions", + D="outfile", + R="region", + N="gaussian", + Q="quadratic", + V="verbose", + h="header", + ) + @kwargs_to_strings(R="sequence") + def compute_bins(grid, output_type="pandas", **kwargs): r""" Perform histogram equalization for a grid. @@ -254,6 +165,8 @@ def compute_bins( Full option list at :gmt-docs:`grdhisteq.html` + {aliases} + Parameters ---------- {grid} @@ -314,21 +227,41 @@ def compute_bins( This method does a weighted histogram equalization for geographic grids to account for node area varying with latitude. """ + outfile = kwargs.get("D") output_type = validate_output_table_type(output_type, outfile=outfile) - if header is not None and output_type != "file": + if kwargs.get("h") is not None and output_type != "file": raise GMTInvalidInput("'header' is only allowed with output_type='file'.") with GMTTempFile(suffix=".txt") as tmpfile: - if output_type != "file": - outfile = tmpfile.name - return grdhisteq._grdhisteq( - grid, - output_type=output_type, - outfile=outfile, - divisions=divisions, - quadratic=quadratic, - verbose=verbose, - region=region, - header=header, - ) + with Session() as lib: + with lib.virtualfile_from_data( + check_kind="raster", data=grid + ) as vingrd: + if outfile is None: + kwargs["D"] = outfile = tmpfile.name # output to tmpfile + lib.call_module( + module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd) + ) + + if outfile == tmpfile.name: + # if user did not set outfile, return pd.DataFrame + result = pd.read_csv( + filepath_or_buffer=outfile, + sep="\t", + header=None, + names=["start", "stop", "bin_id"], + dtype={ + "start": np.float32, + "stop": np.float32, + "bin_id": np.uint32, + }, + ) + elif outfile != tmpfile.name: + # return None if outfile set, output in outfile + return None + + if output_type == "numpy": + return result.to_numpy() + + return result.set_index("bin_id") diff --git a/pygmt/tests/test_grdhisteq.py b/pygmt/tests/test_grdhisteq.py index e2fb5b41d18..a0a55d01d83 100644 --- a/pygmt/tests/test_grdhisteq.py +++ b/pygmt/tests/test_grdhisteq.py @@ -139,11 +139,3 @@ def test_compute_bins_invalid_format(grid): grdhisteq.compute_bins(grid=grid, output_type=1) with pytest.raises(GMTInvalidInput): grdhisteq.compute_bins(grid=grid, output_type="pandas", header="o+c") - - -def test_equalize_grid_invalid_format(grid): - """ - Test that equalize_grid fails with incorrect format. - """ - with pytest.raises(GMTInvalidInput): - grdhisteq.equalize_grid(grid=grid, outgrid=True)