Skip to content

Commit 8a4b197

Browse files
committed
pygmt.grdhisteq: Refactor to make it easier to maintain
1 parent 5014591 commit 8a4b197

File tree

2 files changed

+61
-134
lines changed

2 files changed

+61
-134
lines changed

pygmt/src/grdhisteq.py

+61-126
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ class grdhisteq: # noqa: N801
5454
@fmt_docstring
5555
@use_alias(
5656
C="divisions",
57-
D="outfile",
5857
G="outgrid",
5958
R="region",
6059
N="gaussian",
@@ -63,89 +62,7 @@ class grdhisteq: # noqa: N801
6362
h="header",
6463
)
6564
@kwargs_to_strings(R="sequence")
66-
def _grdhisteq(grid, output_type, **kwargs):
67-
r"""
68-
Perform histogram equalization for a grid.
69-
70-
Must provide ``outfile`` or ``outgrid``.
71-
72-
Full option list at :gmt-docs:`grdhisteq.html`
73-
74-
{aliases}
75-
76-
Parameters
77-
----------
78-
{grid}
79-
{outgrid}
80-
outfile : str, bool, or None
81-
The name of the output ASCII file to store the results of the
82-
histogram equalization in.
83-
output_type: str
84-
Determine the output type. Use "file", "xarray", "pandas", or
85-
"numpy".
86-
divisions : int
87-
Set the number of divisions of the data range [Default is ``16``].
88-
89-
{region}
90-
{verbose}
91-
{header}
92-
93-
Returns
94-
-------
95-
ret: pandas.DataFrame or xarray.DataArray or None
96-
Return type depends on whether the ``outgrid`` parameter is set:
97-
98-
- xarray.DataArray if ``output_type`` is "xarray""
99-
- numpy.ndarray if ``output_type`` is "numpy"
100-
- pandas.DataFrame if ``output_type`` is "pandas"
101-
- None if ``output_type`` is "file" (output is stored in
102-
``outgrid`` or ``outfile``)
103-
104-
See Also
105-
--------
106-
:func:`pygmt.grd2cpt`
107-
"""
108-
109-
with Session() as lib:
110-
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
111-
with file_context as infile:
112-
lib.call_module(
113-
module="grdhisteq", args=build_arg_string(kwargs, infile=infile)
114-
)
115-
116-
if output_type == "file":
117-
return None
118-
if output_type == "xarray":
119-
return load_dataarray(kwargs["G"])
120-
121-
result = pd.read_csv(
122-
filepath_or_buffer=kwargs["D"],
123-
sep="\t",
124-
header=None,
125-
names=["start", "stop", "bin_id"],
126-
dtype={
127-
"start": np.float32,
128-
"stop": np.float32,
129-
"bin_id": np.uint32,
130-
},
131-
)
132-
if output_type == "numpy":
133-
return result.to_numpy()
134-
135-
return result.set_index("bin_id")
136-
137-
@staticmethod
138-
@fmt_docstring
139-
def equalize_grid(
140-
grid,
141-
*,
142-
outgrid=None,
143-
divisions=None,
144-
region=None,
145-
gaussian=None,
146-
quadratic=None,
147-
verbose=None,
148-
):
65+
def equalize_grid(grid, **kwargs):
14966
r"""
15067
Perform histogram equalization for a grid.
15168
@@ -157,6 +74,8 @@ def equalize_grid(
15774
15875
Full option list at :gmt-docs:`grdhisteq.html`
15976
77+
{aliases}
78+
16079
Parameters
16180
----------
16281
{grid}
@@ -202,39 +121,31 @@ def equalize_grid(
202121
This method does a weighted histogram equalization for geographic
203122
grids to account for node area varying with latitude.
204123
"""
205-
# Return an xarray.DataArray if ``outgrid`` is not set
206124
with GMTTempFile(suffix=".nc") as tmpfile:
207-
if isinstance(outgrid, str):
208-
output_type = "file"
209-
elif outgrid is None:
210-
output_type = "xarray"
211-
outgrid = tmpfile.name
212-
else:
213-
raise GMTInvalidInput("Must specify 'outgrid' as a string or None.")
214-
return grdhisteq._grdhisteq(
215-
grid=grid,
216-
output_type=output_type,
217-
outgrid=outgrid,
218-
divisions=divisions,
219-
region=region,
220-
gaussian=gaussian,
221-
quadratic=quadratic,
222-
verbose=verbose,
223-
)
125+
with Session() as lib:
126+
with lib.virtualfile_from_data(
127+
check_kind="raster", data=grid
128+
) as vingrd:
129+
if (outgrid := kwargs.get("G")) is None:
130+
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile
131+
lib.call_module(
132+
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
133+
)
134+
return load_dataarray(outgrid) if outgrid == tmpfile.name else None
224135

225136
@staticmethod
226137
@fmt_docstring
227-
def compute_bins(
228-
grid,
229-
*,
230-
output_type="pandas",
231-
outfile=None,
232-
divisions=None,
233-
quadratic=None,
234-
verbose=None,
235-
region=None,
236-
header=None,
237-
):
138+
@use_alias(
139+
C="divisions",
140+
D="outfile",
141+
R="region",
142+
N="gaussian",
143+
Q="quadratic",
144+
V="verbose",
145+
h="header",
146+
)
147+
@kwargs_to_strings(R="sequence")
148+
def compute_bins(grid, output_type="pandas", **kwargs):
238149
r"""
239150
Perform histogram equalization for a grid.
240151
@@ -254,6 +165,8 @@ def compute_bins(
254165
255166
Full option list at :gmt-docs:`grdhisteq.html`
256167
168+
{aliases}
169+
257170
Parameters
258171
----------
259172
{grid}
@@ -314,21 +227,43 @@ def compute_bins(
314227
This method does a weighted histogram equalization for geographic
315228
grids to account for node area varying with latitude.
316229
"""
230+
outfile = kwargs.get("D")
317231
output_type = validate_output_table_type(output_type, outfile=outfile)
318232

319-
if header is not None and output_type != "file":
233+
if kwargs.get("h") is not None and output_type != "file":
320234
raise GMTInvalidInput("'header' is only allowed with output_type='file'.")
321235

322236
with GMTTempFile(suffix=".txt") as tmpfile:
323-
if output_type != "file":
324-
outfile = tmpfile.name
325-
return grdhisteq._grdhisteq(
326-
grid,
327-
output_type=output_type,
328-
outfile=outfile,
329-
divisions=divisions,
330-
quadratic=quadratic,
331-
verbose=verbose,
332-
region=region,
333-
header=header,
334-
)
237+
with Session() as lib:
238+
with lib.virtualfile_from_data(
239+
check_kind="raster", data=grid
240+
) as vingrd:
241+
if outfile is None:
242+
kwargs["D"] = outfile = tmpfile.name # output to tmpfile
243+
lib.call_module(
244+
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
245+
)
246+
247+
if (
248+
outfile == tmpfile.name
249+
): # if user did not set outfile, return pd.DataFrame
250+
result = pd.read_csv(
251+
filepath_or_buffer=outfile,
252+
sep="\t",
253+
header=None,
254+
names=["start", "stop", "bin_id"],
255+
dtype={
256+
"start": np.float32,
257+
"stop": np.float32,
258+
"bin_id": np.uint32,
259+
},
260+
)
261+
elif (
262+
outfile != tmpfile.name
263+
): # return None if outfile set, output in outfile
264+
return None
265+
266+
if output_type == "numpy":
267+
return result.to_numpy()
268+
269+
return result.set_index("bin_id")

pygmt/tests/test_grdhisteq.py

-8
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,3 @@ def test_compute_bins_invalid_format(grid):
139139
grdhisteq.compute_bins(grid=grid, output_type=1)
140140
with pytest.raises(GMTInvalidInput):
141141
grdhisteq.compute_bins(grid=grid, output_type="pandas", header="o+c")
142-
143-
144-
def test_equalize_grid_invalid_format(grid):
145-
"""
146-
Test that equalize_grid fails with incorrect format.
147-
"""
148-
with pytest.raises(GMTInvalidInput):
149-
grdhisteq.equalize_grid(grid=grid, outgrid=True)

0 commit comments

Comments
 (0)