Skip to content

Commit 7b43427

Browse files
seismanweiji14
andcommitted
clib.converison._to_numpy: Add tests for numpy arrays of numpy numeric dtypes (#3583)
Co-authored-by: Wei Ji <[email protected]>
1 parent 9e196f3 commit 7b43427

File tree

1 file changed

+154
-0
lines changed

1 file changed

+154
-0
lines changed

pygmt/tests/test_clib_to_numpy.py

+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""
2+
Tests for the _to_numpy function in the clib.conversion module.
3+
"""
4+
5+
import sys
6+
7+
import numpy as np
8+
import numpy.testing as npt
9+
import pandas as pd
10+
import pytest
11+
from packaging.version import Version
12+
from pygmt.clib.conversion import _to_numpy
13+
14+
15+
def _check_result(result, expected_dtype):
16+
"""
17+
A helper function to check if the result of the _to_numpy function is a C-contiguous
18+
NumPy array with the expected dtype.
19+
"""
20+
assert isinstance(result, np.ndarray)
21+
assert result.flags.c_contiguous
22+
assert result.dtype.type == expected_dtype
23+
24+
25+
########################################################################################
26+
# Test the _to_numpy function with Python built-in types.
27+
########################################################################################
28+
@pytest.mark.parametrize(
29+
("data", "expected_dtype"),
30+
[
31+
pytest.param(
32+
[1, 2, 3],
33+
np.int32
34+
if sys.platform == "win32" and Version(np.__version__) < Version("2.0")
35+
else np.int64,
36+
id="int",
37+
),
38+
pytest.param([1.0, 2.0, 3.0], np.float64, id="float"),
39+
pytest.param(
40+
[complex(+1), complex(-2j), complex("-Infinity+NaNj")],
41+
np.complex128,
42+
id="complex",
43+
),
44+
],
45+
)
46+
def test_to_numpy_python_types_numeric(data, expected_dtype):
47+
"""
48+
Test the _to_numpy function with Python built-in numeric types.
49+
"""
50+
result = _to_numpy(data)
51+
_check_result(result, expected_dtype)
52+
npt.assert_array_equal(result, data)
53+
54+
55+
########################################################################################
56+
# Test the _to_numpy function with NumPy arrays.
57+
#
58+
# There are 24 fundamental dtypes in NumPy. Not all of them are supported by PyGMT.
59+
#
60+
# - Numeric dtypes:
61+
# - int8, int16, int32, int64, longlong
62+
# - uint8, uint16, uint32, uint64, ulonglong
63+
# - float16, float32, float64, longdouble
64+
# - complex64, complex128, clongdouble
65+
# - bool
66+
# - datetime64, timedelta64
67+
# - str_
68+
# - bytes_
69+
# - object_
70+
# - void
71+
#
72+
# Reference: https://numpy.org/doc/2.1/reference/arrays.scalars.html
73+
########################################################################################
74+
np_dtype_params = [
75+
pytest.param(np.int8, np.int8, id="int8"),
76+
pytest.param(np.int16, np.int16, id="int16"),
77+
pytest.param(np.int32, np.int32, id="int32"),
78+
pytest.param(np.int64, np.int64, id="int64"),
79+
pytest.param(np.longlong, np.longlong, id="longlong"),
80+
pytest.param(np.uint8, np.uint8, id="uint8"),
81+
pytest.param(np.uint16, np.uint16, id="uint16"),
82+
pytest.param(np.uint32, np.uint32, id="uint32"),
83+
pytest.param(np.uint64, np.uint64, id="uint64"),
84+
pytest.param(np.ulonglong, np.ulonglong, id="ulonglong"),
85+
pytest.param(np.float16, np.float16, id="float16"),
86+
pytest.param(np.float32, np.float32, id="float32"),
87+
pytest.param(np.float64, np.float64, id="float64"),
88+
pytest.param(np.longdouble, np.longdouble, id="longdouble"),
89+
pytest.param(np.complex64, np.complex64, id="complex64"),
90+
pytest.param(np.complex128, np.complex128, id="complex128"),
91+
pytest.param(np.clongdouble, np.clongdouble, id="clongdouble"),
92+
]
93+
94+
95+
@pytest.mark.parametrize(("dtype", "expected_dtype"), np_dtype_params)
96+
def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
97+
"""
98+
Test the _to_numpy function with NumPy arrays of NumPy numeric dtypes.
99+
100+
Test both 1-D and 2-D arrays which are not C-contiguous.
101+
"""
102+
# 1-D array that is not C-contiguous
103+
array = np.array([1, 2, 3, 4, 5, 6], dtype=dtype)[::2]
104+
assert array.flags.c_contiguous is False
105+
result = _to_numpy(array)
106+
_check_result(result, expected_dtype)
107+
npt.assert_array_equal(result, array, strict=True)
108+
109+
# 2-D array that is not C-contiguous
110+
array = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype)[::2, ::2]
111+
assert array.flags.c_contiguous is False
112+
result = _to_numpy(array)
113+
_check_result(result, expected_dtype)
114+
npt.assert_array_equal(result, array, strict=True)
115+
116+
117+
########################################################################################
118+
# Test the _to_numpy function with pandas.Series.
119+
#
120+
# In pandas, dtype can be specified by
121+
#
122+
# 1. NumPy dtypes (see above)
123+
# 2. pandas dtypes
124+
# 3. PyArrow dtypes
125+
#
126+
# pandas provides following dtypes:
127+
#
128+
# - Numeric dtypes:
129+
# - Int8, Int16, Int32, Int64
130+
# - UInt8, UInt16, UInt32, UInt64
131+
# - Float32, Float64
132+
# - DatetimeTZDtype
133+
# - PeriodDtype
134+
# - IntervalDtype
135+
# - StringDtype
136+
# - CategoricalDtype
137+
# - SparseDtype
138+
# - BooleanDtype
139+
# - ArrowDtype: a special dtype used to store data in the PyArrow format.
140+
#
141+
# References:
142+
# 1. https://pandas.pydata.org/docs/reference/arrays.html
143+
# 2. https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes
144+
# 3. https://pandas.pydata.org/docs/user_guide/pyarrow.html
145+
########################################################################################
146+
@pytest.mark.parametrize(("dtype", "expected_dtype"), np_dtype_params)
147+
def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype):
148+
"""
149+
Test the _to_numpy function with pandas.Series of NumPy numeric dtypes.
150+
"""
151+
series = pd.Series([1, 2, 3, 4, 5, 6], dtype=dtype)[::2] # Not C-contiguous
152+
result = _to_numpy(series)
153+
_check_result(result, expected_dtype)
154+
npt.assert_array_equal(result, series)

0 commit comments

Comments
 (0)