Skip to content

Commit ce73717

Browse files
seismanweiji14
andauthored
grdtrack: Fix the bug when profile is given (#1867)
Co-authored-by: Wei Ji <[email protected]>
1 parent 31ccd25 commit ce73717

File tree

2 files changed

+112
-15
lines changed

2 files changed

+112
-15
lines changed

pygmt/src/grdtrack.py

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""
22
grdtrack - Sample grids at specified (x,y) locations.
33
"""
4+
import warnings
5+
46
import pandas as pd
7+
import xarray as xr
58
from pygmt.clib import Session
69
from pygmt.exceptions import GMTInvalidInput
710
from pygmt.helpers import (
@@ -11,6 +14,7 @@
1114
kwargs_to_strings,
1215
use_alias,
1316
)
17+
from pygmt.src.which import which
1418

1519
__doctest_skip__ = ["grdtrack"]
1620

@@ -43,7 +47,7 @@
4347
w="wrap",
4448
)
4549
@kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
46-
def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
50+
def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
4751
r"""
4852
Sample grids at specified (x,y) locations.
4953
@@ -67,14 +71,14 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
6771
6872
Parameters
6973
----------
70-
points : str or {table-like}
71-
Pass in either a file name to an ASCII data table, a 2D
72-
{table-classes}.
73-
7474
grid : xarray.DataArray or str
7575
Gridded array from which to sample values from, or a filename (netcdf
7676
format).
7777
78+
points : str or {table-like}
79+
Pass in either a file name to an ASCII data table, a 2D
80+
{table-classes}.
81+
7882
newcolname : str
7983
Required if ``points`` is a :class:`pandas.DataFrame`. The name for the
8084
new column in the track :class:`pandas.DataFrame` table where the
@@ -283,26 +287,65 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
283287
... points=points, grid=grid, newcolname="bathymetry"
284288
... )
285289
"""
290+
# pylint: disable=too-many-branches
291+
if points is not None and kwargs.get("E") is not None:
292+
raise GMTInvalidInput("Can't set both 'points' and 'profile'.")
293+
294+
if points is None and kwargs.get("E") is None:
295+
raise GMTInvalidInput("Must give 'points' or set 'profile'.")
296+
286297
if hasattr(points, "columns") and newcolname is None:
287298
raise GMTInvalidInput("Please pass in a str to 'newcolname'")
288299

300+
# Backward compatibility with old parameter order "points, grid".
301+
# deprecated_version="0.7.0", remove_version="v0.9.0"
302+
is_a_grid = True
303+
if not isinstance(grid, (xr.DataArray, str)):
304+
is_a_grid = False
305+
elif isinstance(grid, str):
306+
try:
307+
xr.open_dataarray(which(grid, download="a"), engine="netcdf4").close()
308+
is_a_grid = True
309+
except (ValueError, OSError):
310+
is_a_grid = False
311+
if not is_a_grid:
312+
msg = (
313+
"Positional parameters 'points, grid' of pygmt.grdtrack() has changed "
314+
"to 'grid, points=None' since v0.7.0. It's likely that you're NOT "
315+
"passing a valid grid as the first positional argument or "
316+
"are passing an invalid grid to the 'grid' parameter. "
317+
"Please check the order of arguments with the latest documentation. "
318+
"This warning will be removed in v0.9.0."
319+
)
320+
grid, points = points, grid
321+
warnings.warn(msg, category=FutureWarning, stacklevel=1)
322+
289323
with GMTTempFile(suffix=".csv") as tmpfile:
290324
with Session() as lib:
291-
# Choose how data will be passed into the module
292-
table_context = lib.virtualfile_from_data(check_kind="vector", data=points)
293325
# Store the xarray.DataArray grid in virtualfile
294326
grid_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
295327

296-
# Run grdtrack on the temporary (csv) points table
297-
# and (netcdf) grid virtualfile
298-
with table_context as csvfile:
299-
with grid_context as grdfile:
300-
kwargs.update({"G": grdfile})
301-
if outfile is None: # Output to tmpfile if outfile is not set
302-
outfile = tmpfile.name
328+
with grid_context as grdfile:
329+
kwargs.update({"G": grdfile})
330+
if outfile is None: # Output to tmpfile if outfile is not set
331+
outfile = tmpfile.name
332+
333+
if points is not None:
334+
# Choose how data will be passed into the module
335+
table_context = lib.virtualfile_from_data(
336+
check_kind="vector", data=points
337+
)
338+
with table_context as csvfile:
339+
lib.call_module(
340+
module="grdtrack",
341+
args=build_arg_string(
342+
kwargs, infile=csvfile, outfile=outfile
343+
),
344+
)
345+
else:
303346
lib.call_module(
304347
module="grdtrack",
305-
args=build_arg_string(kwargs, infile=csvfile, outfile=outfile),
348+
args=build_arg_string(kwargs, outfile=outfile),
306349
)
307350

308351
# Read temporary csv output to a pandas table

pygmt/tests/test_grdtrack.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,29 @@ def test_grdtrack_input_csvfile_and_ncfile_to_dataframe(expected_array):
9999
npt.assert_allclose(np.array(output), expected_array)
100100

101101

102+
def test_grdtrack_profile(dataarray):
103+
"""
104+
Run grdtrack by passing a profile.
105+
"""
106+
output = grdtrack(grid=dataarray, profile="-51/-17/-54/-19")
107+
assert isinstance(output, pd.DataFrame)
108+
npt.assert_allclose(
109+
np.array(output),
110+
np.array(
111+
[
112+
[-51.0, -17.0, 669.671875],
113+
[-51.42430204, -17.28838525, 847.40745877],
114+
[-51.85009439, -17.57598444, 885.30534844],
115+
[-52.27733766, -17.86273467, 829.85423488],
116+
[-52.70599151, -18.14857333, 776.83702212],
117+
[-53.13601473, -18.43343819, 631.07867839],
118+
[-53.56736521, -18.7172675, 504.28037216],
119+
[-54.0, -19.0, 486.10351562],
120+
]
121+
),
122+
)
123+
124+
102125
def test_grdtrack_wrong_kind_of_points_input(dataarray, dataframe):
103126
"""
104127
Run grdtrack using points input that is not a pandas.DataFrame (matrix) or
@@ -137,3 +160,34 @@ def test_grdtrack_without_outfile_setting(dataarray, dataframe):
137160
"""
138161
with pytest.raises(GMTInvalidInput):
139162
grdtrack(points=dataframe, grid=dataarray)
163+
164+
165+
def test_grdtrack_no_points_and_profile(dataarray):
166+
"""
167+
Run grdtrack but don't set 'points' and 'profile'.
168+
"""
169+
with pytest.raises(GMTInvalidInput):
170+
grdtrack(grid=dataarray)
171+
172+
173+
def test_grdtrack_set_points_and_profile(dataarray, dataframe):
174+
"""
175+
Run grdtrack but set both 'points' and 'profile'.
176+
"""
177+
with pytest.raises(GMTInvalidInput):
178+
grdtrack(grid=dataarray, points=dataframe, profile="BL/TR")
179+
180+
181+
def test_grdtrack_old_parameter_order(dataframe, dataarray, expected_array):
182+
"""
183+
Run grdtrack with the old parameter order 'points, grid'.
184+
185+
This test should be removed in v0.9.0.
186+
"""
187+
for points in (POINTS_DATA, dataframe):
188+
for grid in ("@static_earth_relief.nc", dataarray):
189+
with pytest.warns(expected_warning=FutureWarning) as record:
190+
output = grdtrack(points, grid)
191+
assert len(record) == 1
192+
assert isinstance(output, pd.DataFrame)
193+
npt.assert_allclose(np.array(output), expected_array)

0 commit comments

Comments
 (0)