Skip to content

Commit bf7b9a1

Browse files
authored
pygmt.filter1d: Improve performance by storing output in virtual files (#3085)
1 parent 2f598c5 commit bf7b9a1

File tree

2 files changed

+29
-111
lines changed

2 files changed

+29
-111
lines changed

pygmt/src/filter1d.py

+25-37
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
filter1d - Time domain filtering of 1-D data tables
33
"""
44

5+
from typing import Literal
6+
57
import pandas as pd
8+
import xarray as xr
69
from pygmt.clib import Session
710
from pygmt.exceptions import GMTInvalidInput
811
from pygmt.helpers import (
9-
GMTTempFile,
1012
build_arg_string,
1113
fmt_docstring,
1214
use_alias,
@@ -20,7 +22,12 @@
2022
F="filter_type",
2123
N="time_col",
2224
)
23-
def filter1d(data, output_type="pandas", outfile=None, **kwargs):
25+
def filter1d(
26+
data,
27+
output_type: Literal["pandas", "numpy", "file"] = "pandas",
28+
outfile: str | None = None,
29+
**kwargs,
30+
) -> pd.DataFrame | xr.DataArray | None:
2431
r"""
2532
Time domain filtering of 1-D data tables.
2633
@@ -38,6 +45,8 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs):
3845
3946
Parameters
4047
----------
48+
{output_type}
49+
{outfile}
4150
filter_type : str
4251
**type**\ *width*\ [**+h**].
4352
Set the filter **type**. Choose among convolution and non-convolution
@@ -91,48 +100,27 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs):
91100
left-most column is 0, while the right-most is (*n_cols* - 1)
92101
[Default is ``0``].
93102
94-
output_type : str
95-
Determine the format the xyz data will be returned in [Default is
96-
``pandas``]:
97-
98-
- ``numpy`` - :class:`numpy.ndarray`
99-
- ``pandas``- :class:`pandas.DataFrame`
100-
- ``file`` - ASCII file (requires ``outfile``)
101-
outfile : str
102-
The file name for the output ASCII file.
103-
104103
Returns
105104
-------
106-
ret : pandas.DataFrame or numpy.ndarray or None
105+
ret
107106
Return type depends on ``outfile`` and ``output_type``:
108107
109-
- None if ``outfile`` is set (output will be stored in file set by
110-
``outfile``)
111-
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is
112-
not set (depends on ``output_type`` [Default is
113-
:class:`pandas.DataFrame`])
108+
- None if ``outfile`` is set (output will be stored in file set by ``outfile``)
109+
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
110+
(depends on ``output_type``)
114111
"""
115112
if kwargs.get("F") is None:
116113
raise GMTInvalidInput("Pass a required argument to 'filter_type'.")
117114

118115
output_type = validate_output_table_type(output_type, outfile=outfile)
119116

120-
with GMTTempFile() as tmpfile:
121-
with Session() as lib:
122-
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
123-
if outfile is None:
124-
outfile = tmpfile.name
125-
lib.call_module(
126-
module="filter1d",
127-
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
128-
)
129-
130-
# Read temporary csv output to a pandas table
131-
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
132-
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
133-
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
134-
result = None
135-
136-
if output_type == "numpy":
137-
result = result.to_numpy()
138-
return result
117+
with Session() as lib:
118+
with (
119+
lib.virtualfile_in(check_kind="vector", data=data) as vintbl,
120+
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
121+
):
122+
lib.call_module(
123+
module="filter1d",
124+
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
125+
)
126+
return lib.virtualfile_to_dataset(output_type=output_type, vfname=vouttbl)

pygmt/tests/test_filter1d.py

+4-74
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,10 @@
22
Test pygmt.filter1d.
33
"""
44

5-
from pathlib import Path
6-
7-
import numpy as np
85
import pandas as pd
96
import pytest
107
from pygmt import filter1d
118
from pygmt.datasets import load_sample_data
12-
from pygmt.exceptions import GMTInvalidInput
13-
from pygmt.helpers import GMTTempFile
149

1510

1611
@pytest.fixture(scope="module", name="data")
@@ -21,76 +16,11 @@ def fixture_data():
2116
return load_sample_data(name="maunaloa_co2")
2217

2318

24-
def test_filter1d_no_outfile(data):
19+
@pytest.mark.benchmark
20+
def test_filter1d(data):
2521
"""
26-
Test filter1d with no set outfile.
22+
Test the basic functionality of filter1d.
2723
"""
2824
result = filter1d(data=data, filter_type="g5")
25+
assert isinstance(result, pd.DataFrame)
2926
assert result.shape == (671, 2)
30-
31-
32-
def test_filter1d_file_output(data):
33-
"""
34-
Test that filter1d returns a file output when it is specified.
35-
"""
36-
with GMTTempFile(suffix=".txt") as tmpfile:
37-
result = filter1d(
38-
data=data, filter_type="g5", outfile=tmpfile.name, output_type="file"
39-
)
40-
assert result is None # return value is None
41-
assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists
42-
43-
44-
def test_filter1d_invalid_format(data):
45-
"""
46-
Test that filter1d fails with an incorrect format for output_type.
47-
"""
48-
with pytest.raises(GMTInvalidInput):
49-
filter1d(data=data, filter_type="g5", output_type="a")
50-
51-
52-
def test_filter1d_no_filter(data):
53-
"""
54-
Test that filter1d fails with an argument is missing for filter.
55-
"""
56-
with pytest.raises(GMTInvalidInput):
57-
filter1d(data=data)
58-
59-
60-
def test_filter1d_no_outfile_specified(data):
61-
"""
62-
Test that filter1d fails when outpput_type is set to 'file' but no output file name
63-
is specified.
64-
"""
65-
with pytest.raises(GMTInvalidInput):
66-
filter1d(data=data, filter_type="g5", output_type="file")
67-
68-
69-
def test_filter1d_outfile_incorrect_output_type(data):
70-
"""
71-
Test that filter1d raises a warning when an outfile filename is set but the
72-
output_type is not set to 'file'.
73-
"""
74-
with pytest.warns(RuntimeWarning):
75-
with GMTTempFile(suffix=".txt") as tmpfile:
76-
result = filter1d(
77-
data=data, filter_type="g5", outfile=tmpfile.name, output_type="numpy"
78-
)
79-
assert result is None # return value is None
80-
assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists
81-
82-
83-
@pytest.mark.benchmark
84-
def test_filter1d_format(data):
85-
"""
86-
Test that correct formats are returned.
87-
"""
88-
time_series_default = filter1d(data=data, filter_type="g5")
89-
assert isinstance(time_series_default, pd.DataFrame)
90-
assert time_series_default.shape == (671, 2)
91-
time_series_array = filter1d(data=data, filter_type="g5", output_type="numpy")
92-
assert isinstance(time_series_array, np.ndarray)
93-
assert time_series_array.shape == (671, 2)
94-
time_series_df = filter1d(data=data, filter_type="g5", output_type="pandas")
95-
assert isinstance(time_series_df, pd.DataFrame)
96-
assert time_series_df.shape == (671, 2)

0 commit comments

Comments
 (0)