Skip to content

GMT_IMAGE: Implement the GMT_IMAGE.to_dataarray method for 3-band images #3128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 42 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
bcf43f0
Wrap GMT's standard data type GMT_IMAGE for images
seisman Mar 18, 2024
a052a1a
Initial implementation of to_dataarray method for _GMT_IMAGE class
weiji14 Mar 20, 2024
56a6d65
Merge branch 'main' into datatypes/gmtimage
seisman Apr 17, 2024
f71e79c
Merge branch 'main' into datatypes/gmtimage
weiji14 Jun 18, 2024
4cce4a2
Small typo fixes and add output type-hint for to_dataarray
weiji14 Jun 18, 2024
e02b650
Fix mypy error using np.array([0, 1, 2]) instead of np.arange
weiji14 Jun 18, 2024
f3d4b1f
Parse name and data_attrs from grid/image header
weiji14 Jun 18, 2024
4390136
Transpose array to (band, y, x) order and add doctest for to_dataarray
weiji14 Jun 20, 2024
5f25669
Set registration and gtype from header
weiji14 Jun 20, 2024
a3c6c14
Print basic shape and padding info in _GMT_IMAGE doctest
weiji14 Jun 20, 2024
5888e10
Only set Conventions = CF-1.7 attribute for NetCDF grid type
weiji14 Jun 20, 2024
798e658
Merge branch 'main' into datatypes/gmtimage
weiji14 Jun 20, 2024
3dbf2f2
Remove rioxarray import
weiji14 Jun 20, 2024
6b860bf
Merge branch 'main' into datatypes/gmtimage
seisman Jul 27, 2024
0bf9368
Merge branch 'main' into datatypes/gmtimage
seisman Sep 19, 2024
7d437be
Use enum for grid ids
seisman Sep 19, 2024
268e34e
Fix the band. Starting from 1
seisman Sep 19, 2024
86765e1
Refactor the tests for images
seisman Sep 19, 2024
86f3ffa
In np.reshape, a is a position-only parameter
seisman Sep 20, 2024
cc28247
Improve tests
seisman Sep 20, 2024
1e2c973
Fix one failing doctest due to xarray changes
seisman Sep 20, 2024
734dc28
The np.reshape's newshape parameter is deprecated
seisman Sep 20, 2024
919dc00
Define grid IDs using IntEnum instead of Enum
seisman Sep 20, 2024
b1eacf1
Pass the new shape as a positional parameter
seisman Sep 20, 2024
aa4fdc9
Fix failing tests
seisman Sep 20, 2024
c87a3ec
One more fix
seisman Sep 20, 2024
a20d8a2
One more fix
seisman Sep 20, 2024
926427b
Simplify a doctest
seisman Sep 20, 2024
c73328e
Improve the tests
seisman Sep 20, 2024
fb97daa
Convert ctypes array to numpy array using np.ctypeslib.as_array
seisman Sep 20, 2024
15b8d53
Fix the incorrect value due to floating number conversion in sphinter…
seisman Sep 20, 2024
8433e78
Merge branch 'ctypesarray' into datatypes/gmtimage
seisman Sep 20, 2024
3e3a6f3
Update the to_dataarray method to match the codes in GMT_GRID
seisman Sep 20, 2024
12ef40a
Image data should have uint8 dtype
seisman Sep 20, 2024
f64fbb8
Further improve the tests
seisman Sep 21, 2024
4f2ae48
Merge branch 'main' into datatypes/gmtimage
seisman Sep 24, 2024
d49afed
Add a note that currently only 3-band images are supported
seisman Sep 24, 2024
f70bec0
Merge branch 'main' into datatypes/gmtimage
seisman Sep 28, 2024
2fd13fb
Remove the old GMTGridID enums from pygmt/datatypes/header.py
seisman Sep 28, 2024
9972ba1
A minor fix
seisman Sep 28, 2024
62f0ce0
Set CF-1.7-specific attributes for netCDF formats only
seisman Sep 29, 2024
f715aee
Improve the logic for determining grid/image gtype based on upstream GM…
seisman Sep 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pygmt/datatypes/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def to_dataarray(self) -> xr.DataArray:
title: Produced by grdcut
history: grdcut @earth_relief_01d_p -R-55/-47/-24/-10 -Gstatic_ea...
description: Reduced by Gaussian Cartesian filtering (111.2 km fullwi...
long_name: elevation (m)
actual_range: [190. 981.]
long_name: elevation (m)
>>> da.coords["lon"] # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
<xarray.DataArray 'lon' (lon: 8)>...
array([-54.5, -53.5, -52.5, -51.5, -50.5, -49.5, -48.5, -47.5])
Expand Down
22 changes: 16 additions & 6 deletions pygmt/datatypes/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,17 +210,17 @@ def data_attrs(self) -> dict[str, Any]:
GridFormat.NI,
GridFormat.NF,
GridFormat.ND,
}: # Only set the 'Conventions' attribute for netCDF.
}: # Set attributes specific to CF-1.7 conventions
attrs["Conventions"] = "CF-1.7"
attrs["title"] = self.title.decode()
attrs["history"] = self.command.decode()
attrs["description"] = self.remark.decode()
attrs["title"] = self.title.decode()
attrs["history"] = self.command.decode()
attrs["description"] = self.remark.decode()
attrs["actual_range"] = np.array([self.z_min, self.z_max])
long_name, units = _parse_nameunits(self.z_units.decode())
if long_name:
attrs["long_name"] = long_name
if units:
attrs["units"] = units
attrs["actual_range"] = np.array([self.z_min, self.z_max])
return attrs

@property
Expand Down Expand Up @@ -250,6 +250,16 @@ def gtype(self) -> int:
"lon"/"lat" or have units "degrees_east"/"degrees_north", then the grid is
assumed to be geographic.
"""
gtype = 0 # Cartesian by default

dims = self.dims
gtype = 1 if dims[0] == "lat" and dims[1] == "lon" else 0
if dims[0] == "lat" and dims[1] == "lon":
# Check dimensions for grids that follow CF conventions
gtype = 1
elif self.ProjRefPROJ4 is not None:
# Check ProjRefPROJ4 for images imported via GDAL.
# The logic comes from GMT's `gmtlib_read_image_info` function.
projref = self.ProjRefPROJ4.decode()
if "longlat" in projref or "latlong" in projref:
gtype = 1
return gtype
121 changes: 120 additions & 1 deletion pygmt/datatypes/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import ctypes as ctp
from typing import ClassVar

from pygmt.datatypes.grid import _GMT_GRID_HEADER
import numpy as np
import xarray as xr
from pygmt.datatypes.header import _GMT_GRID_HEADER


class _GMT_IMAGE(ctp.Structure): # noqa: N801
Expand Down Expand Up @@ -63,6 +65,8 @@
array([-179.5, -178.5, ..., 178.5, 179.5])
>>> y # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
array([ 89.5, 88.5, ..., -88.5, -89.5])
>>> data.dtype
dtype('uint8')
>>> data.shape
(180, 360, 3)
>>> data.min(), data.max()
Expand Down Expand Up @@ -91,3 +95,118 @@
# Book-keeping variables "hidden" from the API
("hidden", ctp.c_void_p),
]

def to_dataarray(self) -> xr.DataArray:
"""
Convert a _GMT_IMAGE object to an :class:`xarray.DataArray` object.

Returns
-------
dataarray
A :class:`xarray.DataArray` object.

Examples
--------
>>> from pygmt.clib import Session
>>> with Session() as lib:
... with lib.virtualfile_out(kind="image") as voutimg:
... lib.call_module("read", ["@earth_day_01d_p", voutimg, "-Ti"])
... # Read the image from the virtual file
... image = lib.read_virtualfile(voutimg, kind="image")
... # Convert to xarray.DataArray and use it later
... da = image.contents.to_dataarray()
>>> da # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
<xarray.DataArray 'z' (band: 3, y: 180, x: 360)>...
array([[[ 10, 10, 10, ..., 10, 10, 10],
[ 10, 10, 10, ..., 10, 10, 10],
[ 10, 10, 10, ..., 10, 10, 10],
...,
[192, 193, 193, ..., 193, 192, 191],
[204, 206, 206, ..., 205, 206, 204],
[208, 210, 210, ..., 210, 210, 208]],
<BLANKLINE>
[[ 10, 10, 10, ..., 10, 10, 10],
[ 10, 10, 10, ..., 10, 10, 10],
[ 10, 10, 10, ..., 10, 10, 10],
...,
[186, 187, 188, ..., 187, 186, 185],
[196, 198, 198, ..., 197, 197, 196],
[199, 201, 201, ..., 201, 202, 199]],
<BLANKLINE>
[[ 51, 51, 51, ..., 51, 51, 51],
[ 51, 51, 51, ..., 51, 51, 51],
[ 51, 51, 51, ..., 51, 51, 51],
...,
[177, 179, 179, ..., 178, 177, 177],
[185, 187, 187, ..., 187, 186, 185],
[189, 191, 191, ..., 191, 191, 189]]], dtype=uint8)
Coordinates:
* y (y) float64... 89.5 88.5 87.5 86.5 ... -86.5 -87.5 -88.5 -89.5
* x (x) float64... -179.5 -178.5 -177.5 -176.5 ... 177.5 178.5 179.5
* band (band) uint8... 1 2 3
Attributes:
long_name: z

>>> da.coords["x"] # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
<xarray.DataArray 'x' (x: 360)>...
array([-179.5, -178.5, -177.5, ..., 177.5, 178.5, 179.5])
Coordinates:
* x (x) float64... -179.5 -178.5 -177.5 -176.5 ... 177.5 178.5 179.5
Attributes:
long_name: x
axis: X
actual_range: [-180. 180.]
>>> da.coords["y"] # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
<xarray.DataArray 'y' (y: 180)>...
array([ 89.5, 88.5, 87.5, 86.5, ..., -87.5, -88.5, -89.5])
Coordinates:
* y (y) float64... 89.5 88.5 87.5 86.5 ... -86.5 -87.5 -88.5 -89.5
Attributes:
long_name: y
axis: Y
actual_range: [-90. 90.]
>>> da.gmt.registration, da.gmt.gtype
(1, 1)
"""
# The image header
header = self.header.contents

if header.n_bands != 3:
msg = (

Check warning on line 175 in pygmt/datatypes/image.py

View check run for this annotation

Codecov / codecov/patch

pygmt/datatypes/image.py#L175

Added line #L175 was not covered by tests
f"The raster image has {header.n_bands} band(s). "
"Currently only 3-band images are supported. "
"Please consider submitting a feature request to us. "
)
raise NotImplementedError(msg)

Check warning on line 180 in pygmt/datatypes/image.py

View check run for this annotation

Codecov / codecov/patch

pygmt/datatypes/image.py#L180

Added line #L180 was not covered by tests

# Get dimensions and their attributes from the header.
dims, dim_attrs = header.dims, header.dim_attrs
# The coordinates, given as a tuple of the form (dims, data, attrs)
x = np.ctypeslib.as_array(self.x, shape=(header.n_columns,)).copy()
y = np.ctypeslib.as_array(self.y, shape=(header.n_rows,)).copy()
coords = [
(dims[0], y, dim_attrs[0]),
(dims[1], x, dim_attrs[1]),
("band", np.array([1, 2, 3], dtype=np.uint8), None),
]

# Get DataArray without padding
data = np.ctypeslib.as_array(
self.data, shape=(header.my, header.mx, header.n_bands)
).copy()
pad = header.pad[:]
data = data[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1], :]

# Create the xarray.DataArray object
image = xr.DataArray(
data=data,
coords=coords,
name=header.name,
attrs=header.data_attrs,
Comment on lines +204 to +205
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The image name is currently hardcoded to z, is that ok for an RGB image?

@property
def name(self) -> str:
"""
Name of the grid.
"""
return "z"

The attrs fields might need some work. I'm getting 'actual_range': array([ 1.79769313e+308, -1.79769313e+308])} when loading the @earth_day_01d image.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree they make no sense, but they're consistent with the behavior in GMT.

gmt grdinfo @earth_day_01d
/Users/seisman/.gmt/server/earth/earth_day/earth_day_01d_p.tif: Title: Grid imported via GDAL
/Users/seisman/.gmt/server/earth/earth_day/earth_day_01d_p.tif: Command:
/Users/seisman/.gmt/server/earth/earth_day/earth_day_01d_p.tif: Remark:
/Users/seisman/.gmt/server/earth/earth_day/earth_day_01d_p.tif: Pixel node registration used [Geographic grid]
/Users/seisman/.gmt/server/earth/earth_day/earth_day_01d_p.tif: Grid file format: gd = Import/export through GDAL
/Users/seisman/.gmt/server/earth/earth_day/earth_day_01d_p.tif: x_min: -180 x_max: 180 x_inc: 1 name: x n_columns: 360
/Users/seisman/.gmt/server/earth/earth_day/earth_day_01d_p.tif: y_min: -90 y_max: 90 y_inc: 1 name: y n_rows: 180
/Users/seisman/.gmt/server/earth/earth_day/earth_day_01d_p.tif: v_min: 1.79769313486e+308 v_max: -1.79769313486e+308 name: z
/Users/seisman/.gmt/server/earth/earth_day/earth_day_01d_p.tif: scale_factor: 1 add_offset: 0
/Users/seisman/.gmt/server/earth/earth_day/earth_day_01d_p.tif: Default CPT:
+proj=longlat +R=6378137 +no_defs

GMT's image support was likely added by Joaquim, so you may want to ping him for more information.

).transpose("band", dims[0], dims[1])

# Set GMT accessors.
# Must put at the end, otherwise info gets lost after certain image operations.
image.gmt.registration = header.registration
image.gmt.gtype = header.gtype
return image
88 changes: 60 additions & 28 deletions pygmt/tests/test_clib_read_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pathlib import Path

import numpy as np
import pandas as pd
import pytest
import xarray as xr
Expand All @@ -14,7 +15,7 @@
from pygmt.src import which

try:
import rioxarray # noqa: F401
import rioxarray

_HAS_RIOXARRAY = True
except ImportError:
Expand All @@ -29,6 +30,18 @@ def fixture_expected_xrgrid():
return load_dataarray(which("@static_earth_relief.nc"))


@pytest.fixture(scope="module", name="expected_xrimage")
def fixture_expected_xrimage():
    """
    Reference xr.DataArray for the @earth_day_01d_p file, read via rioxarray.

    The raster is loaded and its "spatial_ref" variable is dropped. If rioxarray
    is not installed, None is returned instead.
    """
    # Without rioxarray there is no reference image to build.
    if not _HAS_RIOXARRAY:
        return None
    with rioxarray.open_rasterio(which("@earth_day_01d_p")) as raster:
        return raster.load().drop_vars("spatial_ref")


def test_clib_read_data_dataset():
"""
Test the Session.read_data method for datasets.
Expand Down Expand Up @@ -98,56 +111,63 @@ def test_clib_read_data_grid_two_steps(expected_xrgrid):

# Read the data
lib.read_data(infile, kind="grid", mode="GMT_DATA_ONLY", data=data_ptr)

# Full check
xrgrid = data_ptr.contents.to_dataarray()
xr.testing.assert_equal(xrgrid, expected_xrgrid)


def test_clib_read_data_grid_actual_image():
def test_clib_read_data_grid_actual_image(expected_xrimage):
"""
Test the Session.read_data method for grid, but actually the file is an image.
"""
with Session() as lib:
data_ptr = lib.read_data(
"@earth_day_01d_p", kind="grid", mode="GMT_CONTAINER_AND_DATA"
)
image = data_ptr.contents
header = image.header.contents
assert header.n_rows == 180
assert header.n_columns == 360
assert header.wesn[:] == [-180.0, 180.0, -90.0, 90.0]
image = lib.read_data("@earth_day_01d_p", kind="grid").contents
# Explicitly check n_bands. Only one band is read for 3-band images.
assert header.n_bands == 1
assert image.header.contents.n_bands == 1

xrimage = image.to_dataarray()
assert xrimage.shape == (180, 360)
assert xrimage.coords["x"].data.min() == -179.5
assert xrimage.coords["x"].data.max() == 179.5
assert xrimage.coords["y"].data.min() == -89.5
assert xrimage.coords["y"].data.max() == 89.5
assert xrimage.data.min() == 10.0
assert xrimage.data.max() == 255.0
# Data are stored as uint8 in images but are converted to float32 when reading
# into a GMT_GRID container.
assert xrimage.data.dtype == np.float32

if _HAS_RIOXARRAY: # Full check if rioxarray is installed.
xrimage = image.to_dataarray()
expected_xrimage = xr.open_dataarray(
which("@earth_day_01d_p"), engine="rasterio"
)
assert expected_xrimage.band.size == 3 # 3-band image.
xr.testing.assert_equal(
xrimage,
expected_xrimage.isel(band=0)
.drop_vars(["band", "spatial_ref"])
.sortby("y"),
expected_xrimage.isel(band=0).drop_vars(["band"]).sortby("y"),
)


# Note: Simplify the tests for images after GMT_IMAGE.to_dataarray() is implemented.
def test_clib_read_data_image():
def test_clib_read_data_image(expected_xrimage):
"""
Test the Session.read_data method for images.
"""
with Session() as lib:
image = lib.read_data("@earth_day_01d_p", kind="image").contents
header = image.header.contents
assert header.n_rows == 180
assert header.n_columns == 360
assert header.n_bands == 3
assert header.wesn[:] == [-180.0, 180.0, -90.0, 90.0]
assert image.data

xrimage = image.to_dataarray()
assert xrimage.shape == (3, 180, 360)
assert xrimage.coords["x"].data.min() == -179.5
assert xrimage.coords["x"].data.max() == 179.5
assert xrimage.coords["y"].data.min() == -89.5
assert xrimage.coords["y"].data.max() == 89.5
assert xrimage.data.min() == 10
assert xrimage.data.max() == 255
assert xrimage.data.dtype == np.uint8

if _HAS_RIOXARRAY: # Full check if rioxarray is installed.
xr.testing.assert_equal(xrimage, expected_xrimage)


def test_clib_read_data_image_two_steps():
def test_clib_read_data_image_two_steps(expected_xrimage):
"""
Test the Session.read_data method for images in two steps, first reading the header
and then the data.
Expand All @@ -166,7 +186,19 @@ def test_clib_read_data_image_two_steps():

# Read the data
lib.read_data(infile, kind="image", mode="GMT_DATA_ONLY", data=data_ptr)
assert image.data

xrimage = image.to_dataarray()
assert xrimage.shape == (3, 180, 360)
assert xrimage.coords["x"].data.min() == -179.5
assert xrimage.coords["x"].data.max() == 179.5
assert xrimage.coords["y"].data.min() == -89.5
assert xrimage.coords["y"].data.max() == 89.5
assert xrimage.data.min() == 10
assert xrimage.data.max() == 255
assert xrimage.data.dtype == np.uint8

if _HAS_RIOXARRAY: # Full check if rioxarray is installed.
xr.testing.assert_equal(xrimage, expected_xrimage)


def test_clib_read_data_fails():
Expand Down
Loading