Skip to content

Commit 68a17a0

Browse files
seisman and weiji14 authored
data_kind: Add more tests to demonstrate the data kind of various data types (#3480)
Co-authored-by: Wei Ji <[email protected]>
1 parent 2d1a8cc commit 68a17a0

10 files changed

+89
-48
lines changed

pygmt/helpers/utils.py

+62-22
Original file line numberDiff line numberDiff line change
@@ -207,13 +207,11 @@ def data_kind(
207207
208208
Parameters
209209
----------
210-
data : str, pathlib.PurePath, None, bool, xarray.DataArray or {table-like}
211-
Pass in either a file name or :class:`pathlib.Path` to an ASCII data
212-
table, an :class:`xarray.DataArray`, a 1-D/2-D
213-
{table-classes} or an option argument.
210+
data
211+
The data to be passed to a GMT module.
214212
required
215-
Set to True when 'data' is required, or False when dealing with
216-
optional virtual files. [Default is True].
213+
Whether 'data' is required. Set to ``False`` when dealing with optional virtual
214+
files.
217215
218216
Returns
219217
-------
@@ -222,30 +220,72 @@ def data_kind(
222220
223221
Examples
224222
--------
223+
>>> import io
224+
>>> from pathlib import Path
225225
>>> import numpy as np
226+
>>> import pandas as pd
226227
>>> import xarray as xr
227-
>>> import pathlib
228-
>>> import io
229-
>>> data_kind(data=None)
230-
'vectors'
231-
>>> data_kind(data=np.arange(10).reshape((5, 2)))
232-
'matrix'
233-
>>> data_kind(data="my-data-file.txt")
234-
'file'
235-
>>> data_kind(data=pathlib.Path("my-data-file.txt"))
236-
'file'
228+
229+
The "arg" kind:
230+
231+
>>> [data_kind(data=data, required=False) for data in (2, 2.0, True, False)]
232+
['arg', 'arg', 'arg', 'arg']
237233
>>> data_kind(data=None, required=False)
238234
'arg'
239-
>>> data_kind(data=2.0, required=False)
240-
'arg'
241-
>>> data_kind(data=True, required=False)
242-
'arg'
243-
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3)))
235+
236+
The "file" kind:
237+
238+
>>> [data_kind(data=data) for data in ("file.txt", ("file1.txt", "file2.txt"))]
239+
['file', 'file']
240+
>>> data_kind(data=Path("file.txt"))
241+
'file'
242+
>>> data_kind(data=(Path("file1.txt"), Path("file2.txt")))
243+
'file'
244+
245+
The "grid" kind:
246+
247+
>>> data_kind(data=xr.DataArray(np.random.rand(4, 3))) # 2-D xarray.DataArray
248+
'grid'
249+
>>> data_kind(data=xr.DataArray(np.arange(12))) # 1-D xarray.DataArray
250+
'grid'
251+
>>> data_kind(data=xr.DataArray(np.random.rand(2, 3, 4, 5))) # 4-D xarray.DataArray
244252
'grid'
245-
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5)))
253+
254+
The "image" kind:
255+
256+
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5))) # 3-D xarray.DataArray
246257
'image'
258+
259+
The "stringio" kind:
260+
247261
>>> data_kind(data=io.StringIO("TEXT1\nTEXT23\n"))
248262
'stringio'
263+
264+
The "matrix" kind:
265+
266+
>>> data_kind(data=np.arange(10)) # 1-D numpy.ndarray
267+
'matrix'
268+
>>> data_kind(data=np.arange(10).reshape((5, 2))) # 2-D numpy.ndarray
269+
'matrix'
270+
>>> data_kind(data=np.arange(60).reshape((3, 4, 5))) # 3-D numpy.ndarray
271+
'matrix'
272+
>>> data_kind(xr.DataArray(np.arange(12), name="x").to_dataset()) # xarray.Dataset
273+
'matrix'
274+
>>> data_kind(data=[1, 2, 3]) # 1-D sequence
275+
'matrix'
276+
>>> data_kind(data=[[1, 2, 3], [4, 5, 6]]) # sequence of sequences
277+
'matrix'
278+
>>> data_kind(data={"x": [1, 2, 3], "y": [4, 5, 6]}) # dictionary
279+
'matrix'
280+
>>> data_kind(data=pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) # pd.DataFrame
281+
'matrix'
282+
>>> data_kind(data=pd.Series([1, 2, 3], name="x")) # pd.Series
283+
'matrix'
284+
285+
The "vectors" kind:
286+
287+
>>> data_kind(data=None)
288+
'vectors'
249289
"""
250290
kind: Literal[
251291
"arg", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"

pygmt/tests/test_blockm.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pygmt import blockmean, blockmode
1313
from pygmt.datasets import load_sample_data
1414
from pygmt.exceptions import GMTInvalidInput
15-
from pygmt.helpers import GMTTempFile, data_kind
15+
from pygmt.helpers import GMTTempFile
1616

1717

1818
@pytest.fixture(scope="module", name="dataframe")
@@ -68,7 +68,6 @@ def test_blockmean_wrong_kind_of_input_table_grid(dataframe):
6868
Run blockmean using table input that is not a pandas.DataFrame or file but a grid.
6969
"""
7070
invalid_table = dataframe.bathymetry.to_xarray()
71-
assert data_kind(invalid_table) == "grid"
7271
with pytest.raises(GMTInvalidInput):
7372
blockmean(data=invalid_table, spacing="5m", region=[245, 255, 20, 30])
7473

pygmt/tests/test_blockmedian.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pygmt import blockmedian
1111
from pygmt.datasets import load_sample_data
1212
from pygmt.exceptions import GMTInvalidInput
13-
from pygmt.helpers import GMTTempFile, data_kind
13+
from pygmt.helpers import GMTTempFile
1414

1515

1616
@pytest.fixture(scope="module", name="dataframe")
@@ -65,7 +65,6 @@ def test_blockmedian_wrong_kind_of_input_table_grid(dataframe):
6565
Run blockmedian using table input that is not a pandas.DataFrame or file but a grid.
6666
"""
6767
invalid_table = dataframe.bathymetry.to_xarray()
68-
assert data_kind(invalid_table) == "grid"
6968
with pytest.raises(GMTInvalidInput):
7069
blockmedian(data=invalid_table, spacing="5m", region=[245, 255, 20, 30])
7170

pygmt/tests/test_geopandas.py

+16
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import pytest
88
from pygmt import Figure, info, makecpt, which
9+
from pygmt.helpers import data_kind
910
from pygmt.helpers.testing import skip_if_no
1011

1112
gpd = pytest.importorskip("geopandas")
@@ -243,3 +244,18 @@ def test_geopandas_plot_int64_as_float(gdf_ridge):
243244
makecpt(cmap="lisbon", series=[10, 60, 10], continuous=True)
244245
fig.colorbar()
245246
return fig
247+
248+
249+
def test_geopandas_data_kind_geopandas(gdf):
250+
"""
251+
Check if geopandas.GeoDataFrame object is recognized as a "geojson" kind.
252+
"""
253+
assert data_kind(data=gdf) == "geojson"
254+
255+
256+
def test_geopandas_data_kind_shapely():
257+
"""
258+
Check if shapely.geometry object is recognized as a "geojson" kind.
259+
"""
260+
polygon = shapely.geometry.Polygon([(20, 10), (23, 10), (23, 14), (20, 14)])
261+
assert data_kind(data=polygon) == "geojson"

pygmt/tests/test_grdtrack.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pytest
1111
from pygmt import grdtrack
1212
from pygmt.exceptions import GMTInvalidInput
13-
from pygmt.helpers import GMTTempFile, data_kind
13+
from pygmt.helpers import GMTTempFile
1414
from pygmt.helpers.testing import load_static_earth_relief
1515

1616
POINTS_DATA = Path(__file__).parent / "data" / "track.txt"
@@ -126,22 +126,18 @@ def test_grdtrack_profile(dataarray):
126126

127127
def test_grdtrack_wrong_kind_of_points_input(dataarray, dataframe):
128128
"""
129-
Run grdtrack using points input that is not a pandas.DataFrame (matrix) or file.
129+
Run grdtrack using points input that is not a pandas.DataFrame or file.
130130
"""
131131
invalid_points = dataframe.longitude.to_xarray()
132-
133-
assert data_kind(invalid_points) == "grid"
134132
with pytest.raises(GMTInvalidInput):
135133
grdtrack(points=invalid_points, grid=dataarray, newcolname="bathymetry")
136134

137135

138136
def test_grdtrack_wrong_kind_of_grid_input(dataarray, dataframe):
139137
"""
140-
Run grdtrack using grid input that is not as xarray.DataArray (grid) or file.
138+
Run grdtrack using grid input that is not an xarray.DataArray or file.
141139
"""
142140
invalid_grid = dataarray.to_dataset()
143-
144-
assert data_kind(invalid_grid) == "matrix"
145141
with pytest.raises(GMTInvalidInput):
146142
grdtrack(points=dataframe, grid=invalid_grid, newcolname="bathymetry")
147143

pygmt/tests/test_grdview.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from pygmt import Figure, grdcut
77
from pygmt.exceptions import GMTInvalidInput
8-
from pygmt.helpers import GMTTempFile, data_kind
8+
from pygmt.helpers import GMTTempFile
99
from pygmt.helpers.testing import load_static_earth_relief
1010

1111

@@ -58,8 +58,6 @@ def test_grdview_wrong_kind_of_grid(xrgrid):
5858
Run grdview using grid input that is not an xarray.DataArray or file.
5959
"""
6060
dataset = xrgrid.to_dataset() # convert xarray.DataArray to xarray.Dataset
61-
assert data_kind(dataset) == "matrix"
62-
6361
fig = Figure()
6462
with pytest.raises(GMTInvalidInput):
6563
fig.grdview(grid=dataset)
@@ -238,8 +236,6 @@ def test_grdview_wrong_kind_of_drapegrid(xrgrid):
238236
Run grdview using drapegrid input that is not an xarray.DataArray or file.
239237
"""
240238
dataset = xrgrid.to_dataset() # convert xarray.DataArray to xarray.Dataset
241-
assert data_kind(dataset) == "matrix"
242-
243239
fig = Figure()
244240
with pytest.raises(GMTInvalidInput):
245241
fig.grdview(grid=xrgrid, drapegrid=dataset)

pygmt/tests/test_nearneighbor.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pygmt import nearneighbor
1212
from pygmt.datasets import load_sample_data
1313
from pygmt.exceptions import GMTInvalidInput
14-
from pygmt.helpers import GMTTempFile, data_kind
14+
from pygmt.helpers import GMTTempFile
1515

1616

1717
@pytest.fixture(scope="module", name="ship_data")
@@ -61,7 +61,6 @@ def test_nearneighbor_wrong_kind_of_input(ship_data):
6161
Run nearneighbor using grid input that is not file/matrix/vectors.
6262
"""
6363
data = ship_data.bathymetry.to_xarray() # convert pandas.Series to xarray.DataArray
64-
assert data_kind(data) == "grid"
6564
with pytest.raises(GMTInvalidInput):
6665
nearneighbor(
6766
data=data, spacing="5m", region=[245, 255, 20, 30], search_radius="10m"

pygmt/tests/test_surface.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import xarray as xr
1010
from pygmt import surface, which
1111
from pygmt.exceptions import GMTInvalidInput
12-
from pygmt.helpers import GMTTempFile, data_kind
12+
from pygmt.helpers import GMTTempFile
1313

1414

1515
@pytest.fixture(scope="module", name="data")
@@ -125,7 +125,6 @@ def test_surface_wrong_kind_of_input(data, region, spacing):
125125
Run surface using grid input that is not file/matrix/vectors.
126126
"""
127127
data = data.z.to_xarray() # convert pandas.Series to xarray.DataArray
128-
assert data_kind(data) == "grid"
129128
with pytest.raises(GMTInvalidInput):
130129
surface(data=data, spacing=spacing, region=region)
131130

pygmt/tests/test_triangulate.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import xarray as xr
1111
from pygmt import triangulate, which
1212
from pygmt.exceptions import GMTInvalidInput
13-
from pygmt.helpers import GMTTempFile, data_kind
13+
from pygmt.helpers import GMTTempFile
1414

1515

1616
@pytest.fixture(scope="module", name="dataframe")
@@ -93,7 +93,6 @@ def test_delaunay_triples_wrong_kind_of_input(dataframe):
9393
Run triangulate.delaunay_triples using grid input that is not file/matrix/vectors.
9494
"""
9595
data = dataframe.z.to_xarray() # convert pandas.Series to xarray.DataArray
96-
assert data_kind(data) == "grid"
9796
with pytest.raises(GMTInvalidInput):
9897
triangulate.delaunay_triples(data=data)
9998

pygmt/tests/test_x2sys_cross.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from pygmt.clib import __gmt_version__
1717
from pygmt.datasets import load_sample_data
1818
from pygmt.exceptions import GMTInvalidInput
19-
from pygmt.helpers import data_kind
2019

2120

2221
@pytest.fixture(name="mock_x2sys_home")
@@ -237,11 +236,10 @@ def test_x2sys_cross_input_two_filenames():
237236

238237
def test_x2sys_cross_invalid_tracks_input_type(tracks):
239238
"""
240-
Run x2sys_cross using tracks input that is not a pandas.DataFrame (matrix) or str
241-
(file) type, which would raise a GMTInvalidInput error.
239+
Run x2sys_cross using tracks input that is not a pandas.DataFrame or str type,
240+
which would raise a GMTInvalidInput error.
242241
"""
243242
invalid_tracks = tracks[0].to_xarray().z
244-
assert data_kind(invalid_tracks) == "grid"
245243
with pytest.raises(GMTInvalidInput):
246244
x2sys_cross(tracks=[invalid_tracks])
247245

0 commit comments

Comments
 (0)