Skip to content

Commit 752305c

Browse files
authored
pygmt.grd2xyz: Improve performance by storing output in virtual files (#3097)
1 parent e3c580f commit 752305c

File tree

3 files changed

+52
-104
lines changed

3 files changed

+52
-104
lines changed

pygmt/helpers/decorators.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,18 @@
254254
input and skip trailing text. **Note**: If ``incols`` is also
255255
used then the columns given to ``outcols`` correspond to the
256256
order after the ``incols`` selection has taken place.""",
257+
"outfile": """
258+
outfile
259+
File name for saving the result data. Required if ``output_type="file"``.
260+
If specified, ``output_type`` will be forced to be ``"file"``.""",
261+
"output_type": """
262+
output_type
263+
Desired output type of the result data.
264+
265+
- ``pandas`` will return a :class:`pandas.DataFrame` object.
266+
- ``numpy`` will return a :class:`numpy.ndarray` object.
267+
- ``file`` will save the result to the file specified by the ``outfile``
268+
parameter.""",
257269
"outgrid": """
258270
outgrid : str or None
259271
Name of the output netCDF grid file. For writing a specific grid

pygmt/src/grd2xyz.py

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,23 @@
22
grd2xyz - Convert grid to data table
33
"""
44

5+
from typing import TYPE_CHECKING, Literal
6+
57
import pandas as pd
68
import xarray as xr
79
from pygmt.clib import Session
810
from pygmt.exceptions import GMTInvalidInput
911
from pygmt.helpers import (
10-
GMTTempFile,
1112
build_arg_string,
1213
fmt_docstring,
1314
kwargs_to_strings,
1415
use_alias,
1516
validate_output_table_type,
1617
)
1718

19+
if TYPE_CHECKING:
20+
from collections.abc import Hashable
21+
1822
__doctest_skip__ = ["grd2xyz"]
1923

2024

@@ -33,7 +37,12 @@
3337
s="skiprows",
3438
)
3539
@kwargs_to_strings(R="sequence", o="sequence_comma")
36-
def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
40+
def grd2xyz(
41+
grid,
42+
output_type: Literal["pandas", "numpy", "file"] = "pandas",
43+
outfile: str | None = None,
44+
**kwargs,
45+
) -> pd.DataFrame | xr.DataArray | None:
3746
r"""
3847
Convert grid to data table.
3948
@@ -47,15 +56,8 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
4756
Parameters
4857
----------
4958
{grid}
50-
output_type : str
51-
Determine the format the xyz data will be returned in [Default is
52-
``pandas``]:
53-
54-
- ``numpy`` - :class:`numpy.ndarray`
55-
- ``pandas``- :class:`pandas.DataFrame`
56-
- ``file`` - ASCII file (requires ``outfile``)
57-
outfile : str
58-
The file name for the output ASCII file.
59+
{output_type}
60+
{outfile}
5961
cstyle : str
6062
[**f**\|\ **i**].
6163
Replace the x- and y-coordinates on output with the corresponding
@@ -118,13 +120,12 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
118120
119121
Returns
120122
-------
121-
ret : pandas.DataFrame or numpy.ndarray or None
123+
ret
122124
Return type depends on ``outfile`` and ``output_type``:
123125
124-
- None if ``outfile`` is set (output will be stored in file set by
125-
``outfile``)
126-
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is
127-
not set (depends on ``output_type``)
126+
- None if ``outfile`` is set (output will be stored in file set by ``outfile``)
127+
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
128+
(depends on ``output_type``)
128129
129130
Example
130131
-------
@@ -149,31 +150,22 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
149150
"or 'file'."
150151
)
151152

152-
# Set the default column names for the pandas dataframe header
153-
dataframe_header = ["x", "y", "z"]
153+
# Set the default column names for the pandas dataframe header.
154+
column_names: list[Hashable] = ["x", "y", "z"]
154155
# Let output pandas column names match input DataArray dimension names
155-
if isinstance(grid, xr.DataArray) and output_type == "pandas":
156+
if output_type == "pandas" and isinstance(grid, xr.DataArray):
156157
# Reverse the dims because it is rows, columns ordered.
157-
dataframe_header = [grid.dims[1], grid.dims[0], grid.name]
158-
159-
with GMTTempFile() as tmpfile:
160-
with Session() as lib:
161-
with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd:
162-
if outfile is None:
163-
outfile = tmpfile.name
164-
lib.call_module(
165-
module="grd2xyz",
166-
args=build_arg_string(kwargs, infile=vingrd, outfile=outfile),
167-
)
168-
169-
# Read temporary csv output to a pandas table
170-
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
171-
result = pd.read_csv(
172-
tmpfile.name, sep="\t", names=dataframe_header, comment=">"
158+
column_names = [grid.dims[1], grid.dims[0], grid.name]
159+
160+
with Session() as lib:
161+
with (
162+
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
163+
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
164+
):
165+
lib.call_module(
166+
module="grd2xyz",
167+
args=build_arg_string(kwargs, infile=vingrd, outfile=vouttbl),
168+
)
169+
return lib.virtualfile_to_dataset(
170+
output_type=output_type, vfname=vouttbl, column_names=column_names
173171
)
174-
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
175-
result = None
176-
177-
if output_type == "numpy":
178-
result = result.to_numpy()
179-
return result

pygmt/tests/test_grd2xyz.py

Lines changed: 7 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
Test pygmt.grd2xyz.
33
"""
44

5-
from pathlib import Path
6-
75
import numpy as np
86
import pandas as pd
97
import pytest
108
from pygmt import grd2xyz
119
from pygmt.exceptions import GMTInvalidInput
12-
from pygmt.helpers import GMTTempFile
1310
from pygmt.helpers.testing import load_static_earth_relief
1411

1512

@@ -24,70 +21,17 @@ def fixture_grid():
2421
@pytest.mark.benchmark
2522
def test_grd2xyz(grid):
2623
"""
27-
Make sure grd2xyz works as expected.
28-
"""
29-
xyz_data = grd2xyz(grid=grid, output_type="numpy")
30-
assert xyz_data.shape == (112, 3)
31-
32-
33-
def test_grd2xyz_format(grid):
24+
Test the basic functionality of grd2xyz.
3425
"""
35-
Test that correct formats are returned.
36-
"""
37-
lon = -50.5
38-
lat = -18.5
39-
orig_val = grid.sel(lon=lon, lat=lat).to_numpy()
40-
xyz_default = grd2xyz(grid=grid)
41-
xyz_val = xyz_default[(xyz_default["lon"] == lon) & (xyz_default["lat"] == lat)][
42-
"z"
43-
].to_numpy()
44-
assert isinstance(xyz_default, pd.DataFrame)
45-
assert orig_val.size == 1
46-
assert xyz_val.size == 1
47-
np.testing.assert_allclose(orig_val, xyz_val)
48-
xyz_array = grd2xyz(grid=grid, output_type="numpy")
49-
assert isinstance(xyz_array, np.ndarray)
50-
xyz_df = grd2xyz(grid=grid, output_type="pandas", outcols=None)
26+
xyz_df = grd2xyz(grid=grid)
5127
assert isinstance(xyz_df, pd.DataFrame)
5228
assert list(xyz_df.columns) == ["lon", "lat", "z"]
29+
assert xyz_df.shape == (112, 3)
5330

54-
55-
def test_grd2xyz_file_output(grid):
56-
"""
57-
Test that grd2xyz returns a file output when it is specified.
58-
"""
59-
with GMTTempFile(suffix=".xyz") as tmpfile:
60-
result = grd2xyz(grid=grid, outfile=tmpfile.name, output_type="file")
61-
assert result is None # return value is None
62-
assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists
63-
64-
65-
def test_grd2xyz_invalid_format(grid):
66-
"""
67-
Test that grd2xyz fails with incorrect format.
68-
"""
69-
with pytest.raises(GMTInvalidInput):
70-
grd2xyz(grid=grid, output_type=1)
71-
72-
73-
def test_grd2xyz_no_outfile(grid):
74-
"""
75-
Test that grd2xyz fails when a string output is set with no outfile.
76-
"""
77-
with pytest.raises(GMTInvalidInput):
78-
grd2xyz(grid=grid, output_type="file")
79-
80-
81-
def test_grd2xyz_outfile_incorrect_output_type(grid):
82-
"""
83-
Test that grd2xyz raises a warning when an outfile filename is set but the
84-
output_type is not set to 'file'.
85-
"""
86-
with pytest.warns(RuntimeWarning):
87-
with GMTTempFile(suffix=".xyz") as tmpfile:
88-
result = grd2xyz(grid=grid, outfile=tmpfile.name, output_type="numpy")
89-
assert result is None # return value is None
90-
assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists
31+
lon, lat = -50.5, -18.5
32+
orig_val = grid.sel(lon=lon, lat=lat).to_numpy()
33+
xyz_val = xyz_df[(xyz_df["lon"] == lon) & (xyz_df["lat"] == lat)]["z"].to_numpy()
34+
np.testing.assert_allclose(orig_val, xyz_val)
9135

9236

9337
def test_grd2xyz_pandas_output_with_o(grid):

0 commit comments

Comments
 (0)