Skip to content

Commit 24b967a

Browse files
committed
Make all functions/methods have consistent behavior for table output
1 parent 193bd05 commit 24b967a

10 files changed

+207
-229
lines changed

pygmt/src/blockm.py

+56-31
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66
import pandas as pd
77
from pygmt.clib import Session
88
from pygmt.helpers import (
9-
GMTTempFile,
109
build_arg_string,
1110
fmt_docstring,
1211
kwargs_to_strings,
1312
use_alias,
13+
validate_output_table_type,
1414
)
1515

1616
__doctest_skip__ = ["blockmean", "blockmedian", "blockmode"]
1717

1818

19-
def _blockm(block_method, data, x, y, z, outfile, **kwargs):
19+
def _blockm(block_method, data, x, y, z, output_type, outfile, **kwargs):
2020
r"""
2121
Block average (x, y, z) data tables by mean, median, or mode estimation.
2222
@@ -42,30 +42,28 @@ def _blockm(block_method, data, x, y, z, outfile, **kwargs):
4242
- None if ``outfile`` is set (filtered output will be stored in file
4343
set by ``outfile``)
4444
"""
45-
with GMTTempFile(suffix=".csv") as tmpfile:
46-
with Session() as lib:
47-
with lib.virtualfile_in(
45+
output_type = validate_output_table_type(output_type, outfile=outfile)
46+
47+
with Session() as lib:
48+
with (
49+
lib.virtualfile_in(
4850
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
49-
) as vintbl:
50-
# Run blockm* on data table
51-
if outfile is None:
52-
outfile = tmpfile.name
53-
lib.call_module(
54-
module=block_method,
55-
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
56-
)
57-
58-
# Read temporary csv output to a pandas table
59-
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
60-
try:
61-
column_names = data.columns.to_list()
62-
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
63-
except AttributeError: # 'str' object has no attribute 'columns'
64-
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
65-
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
66-
result = None
67-
68-
return result
51+
) as vintbl,
52+
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
53+
):
54+
lib.call_module(
55+
module=block_method,
56+
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
57+
)
58+
column_names = None
59+
if isinstance(data, pd.DataFrame):
60+
column_names = data.columns.to_list()
61+
62+
return lib.return_table(
63+
output_type=output_type,
64+
vfile=vouttbl,
65+
column_names=column_names,
66+
)
6967

7068

7169
@fmt_docstring
@@ -86,7 +84,9 @@ def _blockm(block_method, data, x, y, z, outfile, **kwargs):
8684
w="wrap",
8785
)
8886
@kwargs_to_strings(I="sequence", R="sequence", i="sequence_comma", o="sequence_comma")
89-
def blockmean(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
87+
def blockmean(
88+
data=None, x=None, y=None, z=None, output_type="pandas", outfile=None, **kwargs
89+
):
9090
r"""
9191
Block average (x, y, z) data tables by mean estimation.
9292
@@ -159,7 +159,14 @@ def blockmean(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
159159
>>> data_bmean = pygmt.blockmean(data=data, region=[245, 255, 20, 30], spacing="5m")
160160
"""
161161
return _blockm(
162-
block_method="blockmean", data=data, x=x, y=y, z=z, outfile=outfile, **kwargs
162+
block_method="blockmean",
163+
data=data,
164+
x=x,
165+
y=y,
166+
z=z,
167+
output_type=output_type,
168+
outfile=outfile,
169+
**kwargs,
163170
)
164171

165172

@@ -180,7 +187,9 @@ def blockmean(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
180187
w="wrap",
181188
)
182189
@kwargs_to_strings(I="sequence", R="sequence", i="sequence_comma", o="sequence_comma")
183-
def blockmedian(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
190+
def blockmedian(
191+
data=None, x=None, y=None, z=None, output_type="pandas", outfile=None, **kwargs
192+
):
184193
r"""
185194
Block average (x, y, z) data tables by median estimation.
186195
@@ -246,7 +255,14 @@ def blockmedian(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
246255
... )
247256
"""
248257
return _blockm(
249-
block_method="blockmedian", data=data, x=x, y=y, z=z, outfile=outfile, **kwargs
258+
block_method="blockmedian",
259+
data=data,
260+
x=x,
261+
y=y,
262+
z=z,
263+
output_type=output_type,
264+
outfile=outfile,
265+
**kwargs,
250266
)
251267

252268

@@ -267,7 +283,9 @@ def blockmedian(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
267283
w="wrap",
268284
)
269285
@kwargs_to_strings(I="sequence", R="sequence", i="sequence_comma", o="sequence_comma")
270-
def blockmode(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
286+
def blockmode(
287+
data=None, x=None, y=None, z=None, output_type="pandas", outfile=None, **kwargs
288+
):
271289
r"""
272290
Block average (x, y, z) data tables by mode estimation.
273291
@@ -331,5 +349,12 @@ def blockmode(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
331349
>>> data_bmode = pygmt.blockmode(data=data, region=[245, 255, 20, 30], spacing="5m")
332350
"""
333351
return _blockm(
334-
block_method="blockmode", data=data, x=x, y=y, z=z, outfile=outfile, **kwargs
352+
block_method="blockmode",
353+
data=data,
354+
x=x,
355+
y=y,
356+
z=z,
357+
output_type=output_type,
358+
outfile=outfile,
359+
**kwargs,
335360
)

pygmt/src/filter1d.py

+10-21
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
filter1d - Time domain filtering of 1-D data tables
33
"""
44

5-
import pandas as pd
65
from pygmt.clib import Session
76
from pygmt.exceptions import GMTInvalidInput
87
from pygmt.helpers import (
9-
GMTTempFile,
108
build_arg_string,
119
fmt_docstring,
1210
use_alias,
@@ -117,22 +115,13 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs):
117115

118116
output_type = validate_output_table_type(output_type, outfile=outfile)
119117

120-
with GMTTempFile() as tmpfile:
121-
with Session() as lib:
122-
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
123-
if outfile is None:
124-
outfile = tmpfile.name
125-
lib.call_module(
126-
module="filter1d",
127-
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
128-
)
129-
130-
# Read temporary csv output to a pandas table
131-
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
132-
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
133-
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
134-
result = None
135-
136-
if output_type == "numpy":
137-
result = result.to_numpy()
138-
return result
118+
with Session() as lib:
119+
with (
120+
lib.virtualfile_in(check_kind="vector", data=data) as vintbl,
121+
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
122+
):
123+
lib.call_module(
124+
module="filter1d",
125+
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
126+
)
127+
return lib.return_table(output_type=output_type, vfile=vouttbl)

pygmt/src/grd2xyz.py

+17-26
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
grd2xyz - Convert grid to data table
33
"""
44

5-
import pandas as pd
65
import xarray as xr
76
from pygmt.clib import Session
87
from pygmt.exceptions import GMTInvalidInput
98
from pygmt.helpers import (
10-
GMTTempFile,
119
build_arg_string,
1210
fmt_docstring,
1311
kwargs_to_strings,
@@ -150,30 +148,23 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
150148
)
151149

152150
# Set the default column names for the pandas dataframe header
153-
dataframe_header = ["x", "y", "z"]
151+
column_names = ["x", "y", "z"]
154152
# Let output pandas column names match input DataArray dimension names
155-
if isinstance(grid, xr.DataArray) and output_type == "pandas":
153+
if isinstance(grid, xr.DataArray):
156154
# Reverse the dims because it is rows, columns ordered.
157-
dataframe_header = [grid.dims[1], grid.dims[0], grid.name]
158-
159-
with GMTTempFile() as tmpfile:
160-
with Session() as lib:
161-
with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd:
162-
if outfile is None:
163-
outfile = tmpfile.name
164-
lib.call_module(
165-
module="grd2xyz",
166-
args=build_arg_string(kwargs, infile=vingrd, outfile=outfile),
167-
)
168-
169-
# Read temporary csv output to a pandas table
170-
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
171-
result = pd.read_csv(
172-
tmpfile.name, sep="\t", names=dataframe_header, comment=">"
155+
column_names = [grid.dims[1], grid.dims[0], grid.name]
156+
157+
with Session() as lib:
158+
with (
159+
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
160+
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
161+
):
162+
lib.call_module(
163+
module="grd2xyz",
164+
args=build_arg_string(kwargs, infile=vingrd, outfile=vouttbl),
173165
)
174-
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
175-
result = None
176-
177-
if output_type == "numpy":
178-
result = result.to_numpy()
179-
return result
166+
return lib.return_table(
167+
output_type=output_type,
168+
vfile=vouttbl,
169+
column_names=column_names,
170+
)

pygmt/src/grdhisteq.py

+20-26
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44

55
import numpy as np
6-
import pandas as pd
76
from pygmt.clib import Session
87
from pygmt.exceptions import GMTInvalidInput
98
from pygmt.helpers import (
@@ -231,33 +230,28 @@ def compute_bins(grid, output_type="pandas", **kwargs):
231230
if kwargs.get("h") is not None and output_type != "file":
232231
raise GMTInvalidInput("'header' is only allowed with output_type='file'.")
233232

234-
with GMTTempFile(suffix=".txt") as tmpfile:
235-
with Session() as lib:
236-
with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd:
237-
if outfile is None:
238-
kwargs["D"] = outfile = tmpfile.name # output to tmpfile
239-
lib.call_module(
240-
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
241-
)
233+
with Session() as lib:
234+
with (
235+
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
236+
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
237+
):
238+
kwargs["D"] = vouttbl # -D for output file name
239+
lib.call_module(
240+
module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd)
241+
)
242242

243-
if outfile == tmpfile.name:
244-
# if user did not set outfile, return pd.DataFrame
245-
result = pd.read_csv(
246-
filepath_or_buffer=outfile,
247-
sep="\t",
248-
header=None,
249-
names=["start", "stop", "bin_id"],
250-
dtype={
243+
result = lib.return_table(
244+
output_type=output_type,
245+
vfile=vouttbl,
246+
column_names=["start", "stop", "bin_id"],
247+
)
248+
if output_type == "pandas":
249+
result = result.astype(
250+
{
251251
"start": np.float32,
252252
"stop": np.float32,
253253
"bin_id": np.uint32,
254-
},
254+
}
255255
)
256-
elif outfile != tmpfile.name:
257-
# return None if outfile set, output in outfile
258-
return None
259-
260-
if output_type == "numpy":
261-
return result.to_numpy()
262-
263-
return result.set_index("bin_id")
256+
return result.set_index("bin_id")
257+
return result

pygmt/src/grdtrack.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from pygmt.clib import Session
77
from pygmt.exceptions import GMTInvalidInput
88
from pygmt.helpers import (
9-
GMTTempFile,
109
build_arg_string,
1110
fmt_docstring,
1211
kwargs_to_strings,
1312
use_alias,
13+
validate_output_table_type,
1414
)
1515

1616
__doctest_skip__ = ["grdtrack"]
@@ -44,7 +44,9 @@
4444
w="wrap",
4545
)
4646
@kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
47-
def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
47+
def grdtrack(
48+
grid, points=None, output_type="pandas", outfile=None, newcolname=None, **kwargs
49+
):
4850
r"""
4951
Sample grids at specified (x,y) locations.
5052
@@ -291,30 +293,27 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
291293
if hasattr(points, "columns") and newcolname is None:
292294
raise GMTInvalidInput("Please pass in a str to 'newcolname'")
293295

294-
with GMTTempFile(suffix=".csv") as tmpfile:
295-
with Session() as lib:
296-
with (
297-
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
298-
lib.virtualfile_in(
299-
check_kind="vector", data=points, required_data=False
300-
) as vintbl,
301-
):
302-
kwargs["G"] = vingrd
303-
if outfile is None: # Output to tmpfile if outfile is not set
304-
outfile = tmpfile.name
305-
lib.call_module(
306-
module="grdtrack",
307-
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
308-
)
296+
output_type = validate_output_table_type(output_type, outfile=outfile)
309297

310-
# Read temporary csv output to a pandas table
311-
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
312-
try:
313-
column_names = [*points.columns.to_list(), newcolname]
314-
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
315-
except AttributeError: # 'str' object has no attribute 'columns'
316-
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
317-
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
318-
result = None
298+
column_names = None
299+
if isinstance(points, pd.DataFrame):
300+
column_names = [*points.columns.to_list(), newcolname]
319301

320-
return result
302+
with Session() as lib:
303+
with (
304+
lib.virtualfile_in(check_kind="raster", data=grid) as vingrd,
305+
lib.virtualfile_in(
306+
check_kind="vector", data=points, required_data=False
307+
) as vintbl,
308+
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
309+
):
310+
kwargs["G"] = vingrd
311+
lib.call_module(
312+
module="grdtrack",
313+
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
314+
)
315+
return lib.return_table(
316+
output_type=output_type,
317+
vfile=vouttbl,
318+
column_names=column_names,
319+
)

0 commit comments

Comments
 (0)