Skip to content

Commit 9ec92be

Browse files
authored
Refactor the data_kind and validate_data_input functions (#3335)
1 parent d3101f3 commit 9ec92be

File tree

7 files changed

+44
-69
lines changed

7 files changed

+44
-69
lines changed

pygmt/clib/session.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
GMTVersionError,
3535
)
3636
from pygmt.helpers import (
37+
_validate_data_input,
3738
data_kind,
3839
tempfile_from_geojson,
3940
tempfile_from_image,
@@ -1684,8 +1685,15 @@ def virtualfile_in( # noqa: PLR0912
16841685
... print(fout.read().strip())
16851686
<vector memory>: N = 3 <7/9> <4/6> <1/3>
16861687
"""
1687-
kind = data_kind(
1688-
data, x, y, z, required_z=required_z, required_data=required_data
1688+
kind = data_kind(data, required=required_data)
1689+
_validate_data_input(
1690+
data=data,
1691+
x=x,
1692+
y=y,
1693+
z=z,
1694+
required_z=required_z,
1695+
required_data=required_data,
1696+
kind=kind,
16891697
)
16901698

16911699
if check_kind:

pygmt/helpers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
unique_name,
1616
)
1717
from pygmt.helpers.utils import (
18+
_validate_data_input,
1819
args_in_kwargs,
1920
build_arg_list,
2021
build_arg_string,

pygmt/helpers/utils.py

Lines changed: 29 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import warnings
1313
import webbrowser
1414
from collections.abc import Iterable, Sequence
15-
from typing import Any
15+
from typing import Any, Literal
1616

1717
import xarray as xr
1818
from pygmt.encodings import charset
@@ -79,6 +79,10 @@ def _validate_data_input(
7979
Traceback (most recent call last):
8080
...
8181
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
82+
>>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6])
83+
Traceback (most recent call last):
84+
...
85+
pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z.
8286
>>> _validate_data_input(data="infile", z=[7, 8, 9])
8387
Traceback (most recent call last):
8488
...
@@ -111,77 +115,69 @@ def _validate_data_input(
111115
raise GMTInvalidInput("data must provide x, y, and z columns.")
112116

113117

114-
def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data=True):
118+
def data_kind(
119+
data: Any = None, required: bool = True
120+
) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
115121
"""
116-
Check what kind of data is provided to a module.
122+
Check the kind of data that is provided to a module.
117123
118-
Possible types:
124+
The ``data`` argument can be of any type, but only the following types are supported:
119125
120-
* a file name provided as 'data'
121-
* a pathlib.PurePath object provided as 'data'
122-
* an xarray.DataArray object provided as 'data'
123-
* a 2-D matrix provided as 'data'
124-
* 1-D arrays x and y (and z, optionally)
125-
* an optional argument (None, bool, int or float) provided as 'data'
126-
127-
Arguments should be ``None`` if not used. If doesn't fit any of these
128-
categories (or fits more than one), will raise an exception.
126+
- a string or a :class:`pathlib.PurePath` object or a sequence of them, representing
127+
a file name or a list of file names
128+
- a 2-D or 3-D :class:`xarray.DataArray` object
129+
- a 2-D matrix
130+
- None, bool, int or float type representing an optional argument
131+
- a geo-like Python object that implements ``__geo_interface__`` (e.g.,
132+
geopandas.GeoDataFrame or shapely.geometry)
129133
130134
Parameters
131135
----------
132136
data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like}
133137
Pass in either a file name or :class:`pathlib.Path` to an ASCII data
134138
table, an :class:`xarray.DataArray`, a 1-D/2-D
135139
{table-classes} or an option argument.
136-
x/y : 1-D arrays or None
137-
x and y columns as numpy arrays.
138-
z : 1-D array or None
139-
z column as numpy array. To be used optionally when x and y are given.
140-
required_z : bool
141-
State whether the 'z' column is required.
142-
required_data : bool
140+
required
143141
Set to True when 'data' is required, or False when dealing with
144142
optional virtual files. [Default is True].
145143
146144
Returns
147145
-------
148-
kind : str
149-
One of ``'arg'``, ``'file'``, ``'grid'``, ``image``, ``'geojson'``,
150-
``'matrix'``, or ``'vectors'``.
146+
kind
147+
The data kind.
151148
152149
Examples
153150
--------
154-
155151
>>> import numpy as np
156152
>>> import xarray as xr
157153
>>> import pathlib
158-
>>> data_kind(data=None, x=np.array([1, 2, 3]), y=np.array([4, 5, 6]))
154+
>>> data_kind(data=None)
159155
'vectors'
160-
>>> data_kind(data=np.arange(10).reshape((5, 2)), x=None, y=None)
156+
>>> data_kind(data=np.arange(10).reshape((5, 2)))
161157
'matrix'
162-
>>> data_kind(data="my-data-file.txt", x=None, y=None)
158+
>>> data_kind(data="my-data-file.txt")
163159
'file'
164-
>>> data_kind(data=pathlib.Path("my-data-file.txt"), x=None, y=None)
160+
>>> data_kind(data=pathlib.Path("my-data-file.txt"))
165161
'file'
166-
>>> data_kind(data=None, x=None, y=None, required_data=False)
162+
>>> data_kind(data=None, required=False)
167163
'arg'
168-
>>> data_kind(data=2.0, x=None, y=None, required_data=False)
164+
>>> data_kind(data=2.0, required=False)
169165
'arg'
170-
>>> data_kind(data=True, x=None, y=None, required_data=False)
166+
>>> data_kind(data=True, required=False)
171167
'arg'
172168
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3)))
173169
'grid'
174170
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5)))
175171
'image'
176172
"""
177-
# determine the data kind
173+
kind: Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]
178174
if isinstance(data, str | pathlib.PurePath) or (
179175
isinstance(data, list | tuple)
180176
and all(isinstance(_file, str | pathlib.PurePath) for _file in data)
181177
):
182178
# One or more files
183179
kind = "file"
184-
elif isinstance(data, bool | int | float) or (data is None and not required_data):
180+
elif isinstance(data, bool | int | float) or (data is None and not required):
185181
kind = "arg"
186182
elif isinstance(data, xr.DataArray):
187183
kind = "image" if len(data.dims) == 3 else "grid"
@@ -193,15 +189,6 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data
193189
kind = "matrix"
194190
else:
195191
kind = "vectors"
196-
_validate_data_input(
197-
data=data,
198-
x=x,
199-
y=y,
200-
z=z,
201-
required_z=required_z,
202-
required_data=required_data,
203-
kind=kind,
204-
)
205192
return kind
206193

207194

pygmt/src/plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def plot( # noqa: PLR0912
208208
"""
209209
kwargs = self._preprocess(**kwargs)
210210

211-
kind = data_kind(data, x, y)
211+
kind = data_kind(data)
212212
extra_arrays = []
213213
if kind == "vectors": # Add more columns for vectors input
214214
# Parameters for vector styles

pygmt/src/plot3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def plot3d( # noqa: PLR0912
183183
"""
184184
kwargs = self._preprocess(**kwargs)
185185

186-
kind = data_kind(data, x, y, z)
186+
kind = data_kind(data)
187187
extra_arrays = []
188188

189189
if kind == "vectors": # Add more columns for vectors input

pygmt/src/text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,11 @@ def text_( # noqa: PLR0912
180180

181181
# Ensure inputs are either textfiles, x/y/text, or position/text
182182
if position is None:
183-
if (x is not None or y is not None) and textfiles is not None:
183+
if any(v is not None for v in (x, y, text)) and textfiles is not None:
184184
raise GMTInvalidInput(
185185
"Provide either position only, or x/y pairs, or textfiles."
186186
)
187-
kind = data_kind(textfiles, x, y, text)
187+
kind = data_kind(textfiles)
188188
if kind == "vectors" and text is None:
189189
raise GMTInvalidInput("Must provide text with x/y pairs")
190190
else:

pygmt/tests/test_helpers.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from pathlib import Path
66

7-
import numpy as np
87
import pytest
98
import xarray as xr
109
from pygmt import Figure
@@ -13,7 +12,6 @@
1312
GMTTempFile,
1413
args_in_kwargs,
1514
build_arg_list,
16-
data_kind,
1715
kwargs_to_strings,
1816
unique_name,
1917
)
@@ -33,25 +31,6 @@ def test_load_static_earth_relief():
3331
assert isinstance(data, xr.DataArray)
3432

3533

36-
@pytest.mark.parametrize(
37-
("data", "x", "y"),
38-
[
39-
(None, None, None),
40-
("data.txt", np.array([1, 2]), np.array([4, 5])),
41-
("data.txt", np.array([1, 2]), None),
42-
("data.txt", None, np.array([4, 5])),
43-
(None, np.array([1, 2]), None),
44-
(None, None, np.array([4, 5])),
45-
],
46-
)
47-
def test_data_kind_fails(data, x, y):
48-
"""
49-
Make sure data_kind raises exceptions when it should.
50-
"""
51-
with pytest.raises(GMTInvalidInput):
52-
data_kind(data=data, x=x, y=y)
53-
54-
5534
def test_unique_name():
5635
"""
5736
Make sure the names are really unique.

0 commit comments

Comments
 (0)