Skip to content

Commit 6606845

Browse files
seismanweiji14
andauthored
GMT_IMAGE: Implement the GMT_IMAGE.to_dataarray method for 3-band images (#3128)
Co-authored-by: Wei Ji <[email protected]>
1 parent cf48764 commit 6606845

File tree

4 files changed

+197
-36
lines changed

4 files changed

+197
-36
lines changed

pygmt/datatypes/grid.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ def to_dataarray(self) -> xr.DataArray:
139139
title: Produced by grdcut
140140
history: grdcut @earth_relief_01d_p -R-55/-47/-24/-10 -Gstatic_ea...
141141
description: Reduced by Gaussian Cartesian filtering (111.2 km fullwi...
142-
long_name: elevation (m)
143142
actual_range: [190. 981.]
143+
long_name: elevation (m)
144144
>>> da.coords["lon"] # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
145145
<xarray.DataArray 'lon' (lon: 8)>...
146146
array([-54.5, -53.5, -52.5, -51.5, -50.5, -49.5, -48.5, -47.5])

pygmt/datatypes/header.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -210,17 +210,17 @@ def data_attrs(self) -> dict[str, Any]:
210210
GridFormat.NI,
211211
GridFormat.NF,
212212
GridFormat.ND,
213-
}: # Only set the 'Conventions' attribute for netCDF.
213+
}: # Set attributes specific to CF-1.7 conventions
214214
attrs["Conventions"] = "CF-1.7"
215-
attrs["title"] = self.title.decode()
216-
attrs["history"] = self.command.decode()
217-
attrs["description"] = self.remark.decode()
215+
attrs["title"] = self.title.decode()
216+
attrs["history"] = self.command.decode()
217+
attrs["description"] = self.remark.decode()
218+
attrs["actual_range"] = np.array([self.z_min, self.z_max])
218219
long_name, units = _parse_nameunits(self.z_units.decode())
219220
if long_name:
220221
attrs["long_name"] = long_name
221222
if units:
222223
attrs["units"] = units
223-
attrs["actual_range"] = np.array([self.z_min, self.z_max])
224224
return attrs
225225

226226
@property
@@ -250,6 +250,16 @@ def gtype(self) -> int:
250250
"lon"/"lat" or have units "degrees_east"/"degrees_north", then the grid is
251251
assumed to be geographic.
252252
"""
253+
gtype = 0 # Cartesian by default
254+
253255
dims = self.dims
254-
gtype = 1 if dims[0] == "lat" and dims[1] == "lon" else 0
256+
if dims[0] == "lat" and dims[1] == "lon":
257+
# Check dimensions for grids that following CF-conventions
258+
gtype = 1
259+
elif self.ProjRefPROJ4 is not None:
260+
# Check ProjRefPROJ4 for images imported via GDAL.
261+
# The logic comes from GMT's `gmtlib_read_image_info` function.
262+
projref = self.ProjRefPROJ4.decode()
263+
if "longlat" in projref or "latlong" in projref:
264+
gtype = 1
255265
return gtype

pygmt/datatypes/image.py

+120-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import ctypes as ctp
66
from typing import ClassVar
77

8-
from pygmt.datatypes.grid import _GMT_GRID_HEADER
8+
import numpy as np
9+
import xarray as xr
10+
from pygmt.datatypes.header import _GMT_GRID_HEADER
911

1012

1113
class _GMT_IMAGE(ctp.Structure): # noqa: N801
@@ -63,6 +65,8 @@ class _GMT_IMAGE(ctp.Structure): # noqa: N801
6365
array([-179.5, -178.5, ..., 178.5, 179.5])
6466
>>> y # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
6567
array([ 89.5, 88.5, ..., -88.5, -89.5])
68+
>>> data.dtype
69+
dtype('uint8')
6670
>>> data.shape
6771
(180, 360, 3)
6872
>>> data.min(), data.max()
@@ -91,3 +95,118 @@ class _GMT_IMAGE(ctp.Structure): # noqa: N801
9195
# Book-keeping variables "hidden" from the API
9296
("hidden", ctp.c_void_p),
9397
]
98+
99+
def to_dataarray(self) -> xr.DataArray:
100+
"""
101+
Convert a _GMT_IMAGE object to an :class:`xarray.DataArray` object.
102+
103+
Returns
104+
-------
105+
dataarray
106+
A :class:`xarray.DataArray` object.
107+
108+
Examples
109+
--------
110+
>>> from pygmt.clib import Session
111+
>>> with Session() as lib:
112+
... with lib.virtualfile_out(kind="image") as voutimg:
113+
... lib.call_module("read", ["@earth_day_01d_p", voutimg, "-Ti"])
114+
... # Read the image from the virtual file
115+
... image = lib.read_virtualfile(voutimg, kind="image")
116+
... # Convert to xarray.DataArray and use it later
117+
... da = image.contents.to_dataarray()
118+
>>> da # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
119+
<xarray.DataArray 'z' (band: 3, y: 180, x: 360)>...
120+
array([[[ 10, 10, 10, ..., 10, 10, 10],
121+
[ 10, 10, 10, ..., 10, 10, 10],
122+
[ 10, 10, 10, ..., 10, 10, 10],
123+
...,
124+
[192, 193, 193, ..., 193, 192, 191],
125+
[204, 206, 206, ..., 205, 206, 204],
126+
[208, 210, 210, ..., 210, 210, 208]],
127+
<BLANKLINE>
128+
[[ 10, 10, 10, ..., 10, 10, 10],
129+
[ 10, 10, 10, ..., 10, 10, 10],
130+
[ 10, 10, 10, ..., 10, 10, 10],
131+
...,
132+
[186, 187, 188, ..., 187, 186, 185],
133+
[196, 198, 198, ..., 197, 197, 196],
134+
[199, 201, 201, ..., 201, 202, 199]],
135+
<BLANKLINE>
136+
[[ 51, 51, 51, ..., 51, 51, 51],
137+
[ 51, 51, 51, ..., 51, 51, 51],
138+
[ 51, 51, 51, ..., 51, 51, 51],
139+
...,
140+
[177, 179, 179, ..., 178, 177, 177],
141+
[185, 187, 187, ..., 187, 186, 185],
142+
[189, 191, 191, ..., 191, 191, 189]]], dtype=uint8)
143+
Coordinates:
144+
* y (y) float64... 89.5 88.5 87.5 86.5 ... -86.5 -87.5 -88.5 -89.5
145+
* x (x) float64... -179.5 -178.5 -177.5 -176.5 ... 177.5 178.5 179.5
146+
* band (band) uint8... 1 2 3
147+
Attributes:
148+
long_name: z
149+
150+
>>> da.coords["x"] # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
151+
<xarray.DataArray 'x' (x: 360)>...
152+
array([-179.5, -178.5, -177.5, ..., 177.5, 178.5, 179.5])
153+
Coordinates:
154+
* x (x) float64... -179.5 -178.5 -177.5 -176.5 ... 177.5 178.5 179.5
155+
Attributes:
156+
long_name: x
157+
axis: X
158+
actual_range: [-180. 180.]
159+
>>> da.coords["y"] # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
160+
<xarray.DataArray 'y' (y: 180)>...
161+
array([ 89.5, 88.5, 87.5, 86.5, ..., -87.5, -88.5, -89.5])
162+
Coordinates:
163+
* y (y) float64... 89.5 88.5 87.5 86.5 ... -86.5 -87.5 -88.5 -89.5
164+
Attributes:
165+
long_name: y
166+
axis: Y
167+
actual_range: [-90. 90.]
168+
>>> da.gmt.registration, da.gmt.gtype
169+
(1, 1)
170+
"""
171+
# The image header
172+
header = self.header.contents
173+
174+
if header.n_bands != 3:
175+
msg = (
176+
f"The raster image has {header.n_bands} band(s). "
177+
"Currently only 3-band images are supported. "
178+
"Please consider submitting a feature request to us. "
179+
)
180+
raise NotImplementedError(msg)
181+
182+
# Get dimensions and their attributes from the header.
183+
dims, dim_attrs = header.dims, header.dim_attrs
184+
# The coordinates, given as a tuple of the form (dims, data, attrs)
185+
x = np.ctypeslib.as_array(self.x, shape=(header.n_columns,)).copy()
186+
y = np.ctypeslib.as_array(self.y, shape=(header.n_rows,)).copy()
187+
coords = [
188+
(dims[0], y, dim_attrs[0]),
189+
(dims[1], x, dim_attrs[1]),
190+
("band", np.array([1, 2, 3], dtype=np.uint8), None),
191+
]
192+
193+
# Get DataArray without padding
194+
data = np.ctypeslib.as_array(
195+
self.data, shape=(header.my, header.mx, header.n_bands)
196+
).copy()
197+
pad = header.pad[:]
198+
data = data[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1], :]
199+
200+
# Create the xarray.DataArray object
201+
image = xr.DataArray(
202+
data=data,
203+
coords=coords,
204+
name=header.name,
205+
attrs=header.data_attrs,
206+
).transpose("band", dims[0], dims[1])
207+
208+
# Set GMT accessors.
209+
# Must put at the end, otherwise info gets lost after certain image operations.
210+
image.gmt.registration = header.registration
211+
image.gmt.gtype = header.gtype
212+
return image

pygmt/tests/test_clib_read_data.py

+60-28
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pathlib import Path
66

7+
import numpy as np
78
import pandas as pd
89
import pytest
910
import xarray as xr
@@ -14,7 +15,7 @@
1415
from pygmt.src import which
1516

1617
try:
17-
import rioxarray # noqa: F401
18+
import rioxarray
1819

1920
_HAS_RIOXARRAY = True
2021
except ImportError:
@@ -29,6 +30,18 @@ def fixture_expected_xrgrid():
2930
return load_dataarray(which("@static_earth_relief.nc"))
3031

3132

33+
@pytest.fixture(scope="module", name="expected_xrimage")
34+
def fixture_expected_xrimage():
35+
"""
36+
The expected xr.DataArray object for the @earth_day_01d_p file.
37+
"""
38+
if _HAS_RIOXARRAY:
39+
with rioxarray.open_rasterio(which("@earth_day_01d_p")) as da:
40+
dataarray = da.load().drop_vars("spatial_ref")
41+
return dataarray
42+
return None
43+
44+
3245
def test_clib_read_data_dataset():
3346
"""
3447
Test the Session.read_data method for datasets.
@@ -98,56 +111,63 @@ def test_clib_read_data_grid_two_steps(expected_xrgrid):
98111

99112
# Read the data
100113
lib.read_data(infile, kind="grid", mode="GMT_DATA_ONLY", data=data_ptr)
114+
115+
# Full check
101116
xrgrid = data_ptr.contents.to_dataarray()
102117
xr.testing.assert_equal(xrgrid, expected_xrgrid)
103118

104119

105-
def test_clib_read_data_grid_actual_image():
120+
def test_clib_read_data_grid_actual_image(expected_xrimage):
106121
"""
107122
Test the Session.read_data method for grid, but actually the file is an image.
108123
"""
109124
with Session() as lib:
110-
data_ptr = lib.read_data(
111-
"@earth_day_01d_p", kind="grid", mode="GMT_CONTAINER_AND_DATA"
112-
)
113-
image = data_ptr.contents
114-
header = image.header.contents
115-
assert header.n_rows == 180
116-
assert header.n_columns == 360
117-
assert header.wesn[:] == [-180.0, 180.0, -90.0, 90.0]
125+
image = lib.read_data("@earth_day_01d_p", kind="grid").contents
118126
# Explicitly check n_bands. Only one band is read for 3-band images.
119-
assert header.n_bands == 1
127+
assert image.header.contents.n_bands == 1
128+
129+
xrimage = image.to_dataarray()
130+
assert xrimage.shape == (180, 360)
131+
assert xrimage.coords["x"].data.min() == -179.5
132+
assert xrimage.coords["x"].data.max() == 179.5
133+
assert xrimage.coords["y"].data.min() == -89.5
134+
assert xrimage.coords["y"].data.max() == 89.5
135+
assert xrimage.data.min() == 10.0
136+
assert xrimage.data.max() == 255.0
137+
# Data are stored as uint8 in images but are converted to float32 when reading
138+
# into a GMT_GRID container.
139+
assert xrimage.data.dtype == np.float32
120140

121141
if _HAS_RIOXARRAY: # Full check if rioxarray is installed.
122-
xrimage = image.to_dataarray()
123-
expected_xrimage = xr.open_dataarray(
124-
which("@earth_day_01d_p"), engine="rasterio"
125-
)
126142
assert expected_xrimage.band.size == 3 # 3-band image.
127143
xr.testing.assert_equal(
128144
xrimage,
129-
expected_xrimage.isel(band=0)
130-
.drop_vars(["band", "spatial_ref"])
131-
.sortby("y"),
145+
expected_xrimage.isel(band=0).drop_vars(["band"]).sortby("y"),
132146
)
133147

134148

135-
# Note: Simplify the tests for images after GMT_IMAGE.to_dataarray() is implemented.
136-
def test_clib_read_data_image():
149+
def test_clib_read_data_image(expected_xrimage):
137150
"""
138151
Test the Session.read_data method for images.
139152
"""
140153
with Session() as lib:
141154
image = lib.read_data("@earth_day_01d_p", kind="image").contents
142-
header = image.header.contents
143-
assert header.n_rows == 180
144-
assert header.n_columns == 360
145-
assert header.n_bands == 3
146-
assert header.wesn[:] == [-180.0, 180.0, -90.0, 90.0]
147-
assert image.data
155+
156+
xrimage = image.to_dataarray()
157+
assert xrimage.shape == (3, 180, 360)
158+
assert xrimage.coords["x"].data.min() == -179.5
159+
assert xrimage.coords["x"].data.max() == 179.5
160+
assert xrimage.coords["y"].data.min() == -89.5
161+
assert xrimage.coords["y"].data.max() == 89.5
162+
assert xrimage.data.min() == 10
163+
assert xrimage.data.max() == 255
164+
assert xrimage.data.dtype == np.uint8
165+
166+
if _HAS_RIOXARRAY: # Full check if rioxarray is installed.
167+
xr.testing.assert_equal(xrimage, expected_xrimage)
148168

149169

150-
def test_clib_read_data_image_two_steps():
170+
def test_clib_read_data_image_two_steps(expected_xrimage):
151171
"""
152172
Test the Session.read_data method for images in two steps, first reading the header
153173
and then the data.
@@ -166,7 +186,19 @@ def test_clib_read_data_image_two_steps():
166186

167187
# Read the data
168188
lib.read_data(infile, kind="image", mode="GMT_DATA_ONLY", data=data_ptr)
169-
assert image.data
189+
190+
xrimage = image.to_dataarray()
191+
assert xrimage.shape == (3, 180, 360)
192+
assert xrimage.coords["x"].data.min() == -179.5
193+
assert xrimage.coords["x"].data.max() == 179.5
194+
assert xrimage.coords["y"].data.min() == -89.5
195+
assert xrimage.coords["y"].data.max() == 89.5
196+
assert xrimage.data.min() == 10
197+
assert xrimage.data.max() == 255
198+
assert xrimage.data.dtype == np.uint8
199+
200+
if _HAS_RIOXARRAY: # Full check if rioxarray is installed.
201+
xr.testing.assert_equal(xrimage, expected_xrimage)
170202

171203

172204
def test_clib_read_data_fails():

0 commit comments

Comments
 (0)