Skip to content

Allow passing region to GMTBackendEntrypoint.open_dataset #3932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 4 additions & 19 deletions pygmt/datasets/load_remote_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
from typing import Any, Literal, NamedTuple

import xarray as xr
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import build_arg_list, kwargs_to_strings
from pygmt.src import which
from pygmt.helpers import kwargs_to_strings
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At line 503, @kwargs_to_strings(region="sequence") can be removed, since the region parameter, either a sequence or a string, can be directly passed to xr.load_dataarray.


with contextlib.suppress(ImportError):
# rioxarray is needed to register the rio accessor
Expand Down Expand Up @@ -581,22 +579,9 @@ def _load_remote_dataset(
raise GMTInvalidInput(msg)

fname = f"@{prefix}_{resolution}_{reg}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see a lot of error messages like:

Error: h [ERROR]: Tile @S90W180.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90W150.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90W120.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90W090.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90W060.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90W030.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90E000.earth_age_01m_g.nc not found!
Error: h [ERROR]: Tile @S90E030.earth_age_01m_g.nc not found!

This is because, in the GMT backend, we use something like which("@earth_age_01m_g") to get the file path, which doesn't work well for tiled grids.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we used to do this:

    # Full path to the grid if not tiled grids.
    source = which(fname, download="a") if not resinfo.tiled else None
    # Manually add source to xarray.DataArray encoding to make the GMT accessors work.
    if source:
        grid.encoding["source"] = source

i.e. only add the source for non-tiled grids, so that the accessor's which call doesn't report this error. I'm thinking if it's possible to either 1) silence the which call (does verbose="q" work?), or 2) add some heuristic/logic to determine whether the source is a tiled grid before calling which in GMTBackendEntrypoint

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking if it's possible to either 1) silence the which call (does verbose="q" work?), or 2) add some heuristic/logic to determine whether the source is a tiled grid before calling which in GMTBackendEntrypoint

I think either works. Perhaps verbose="q" is easier?

Copy link
Member Author

@weiji14 weiji14 May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit 5557b33.

Edit: Also just realized that verbose="q" was suggested before in #524 (comment).

kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[dataset.kind]}
with Session() as lib:
with lib.virtualfile_out(kind=dataset.kind) as voutgrd:
lib.call_module(
module="read",
args=[fname, voutgrd, *build_arg_list(kwdict)],
)
grid = lib.virtualfile_to_raster(
kind=dataset.kind, outgrid=None, vfname=voutgrd
)

# Full path to the grid if not tiled grids.
source = which(fname, download="a") if not resinfo.tiled else None
# Manually add source to xarray.DataArray encoding to make the GMT accessors work.
if source:
grid.encoding["source"] = source
grid = xr.load_dataarray(
fname, engine="gmt", raster_kind=dataset.kind, region=region
)

# Add some metadata to the grid
grid.attrs["description"] = dataset.description
Expand Down
37 changes: 24 additions & 13 deletions pygmt/tests/test_xarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,13 @@ def test_xarray_accessor_sliced_datacube():
Path(fname).unlink()


def test_xarray_accessor_grid_source_file_not_exist():
def test_xarray_accessor_tiled_grid_slice_and_add():
"""
Check that the accessor fallbacks to the default registration and gtype when the
grid source file (i.e., grid.encoding["source"]) doesn't exist.
Check that the accessor works to get the registration and gtype when the grid source
file is from a tiled grid, that slicing doesn't affect registration/gtype, but math
operations do return the default registration/gtype as a fallback.

Unit test to track https://github.com/GenericMappingTools/pygmt/issues/524
"""
# Load the 05m earth relief grid, which is stored as tiles.
grid = load_earth_relief(
Expand All @@ -144,17 +147,25 @@ def test_xarray_accessor_grid_source_file_not_exist():
# Registration and gtype are correct.
assert grid.gmt.registration is GridRegistration.PIXEL
assert grid.gmt.gtype is GridType.GEOGRAPHIC
# The source grid file is undefined.
assert grid.encoding.get("source") is None
# The source grid file for tiled grids is the first tile
assert grid.encoding["source"].endswith("S90W180.earth_relief_05m_p.nc")

# For a sliced grid, fallback to default registration and gtype, because the source
# grid file doesn't exist.
# For a sliced grid, ensure we don't fallback to the default registration (gridline)
# and gtype (cartesian), because the source grid file should still exist.
sliced_grid = grid[1:3, 1:3]
assert sliced_grid.gmt.registration is GridRegistration.GRIDLINE
assert sliced_grid.gmt.gtype is GridType.CARTESIAN

# Still possible to manually set registration and gtype.
sliced_grid.gmt.registration = GridRegistration.PIXEL
sliced_grid.gmt.gtype = GridType.GEOGRAPHIC
assert sliced_grid.encoding["source"].endswith("S90W180.earth_relief_05m_p.nc")
assert sliced_grid.gmt.registration is GridRegistration.PIXEL
assert sliced_grid.gmt.gtype is GridType.GEOGRAPHIC

# For a grid that underwent mathematical operations, fallback to default
# registration and gtype, because the source grid file doesn't exist.
added_grid = sliced_grid + 9
assert added_grid.encoding == {}
assert added_grid.gmt.registration is GridRegistration.GRIDLINE
assert added_grid.gmt.gtype is GridType.CARTESIAN

# Still possible to manually set registration and gtype.
added_grid.gmt.registration = GridRegistration.PIXEL
added_grid.gmt.gtype = GridType.GEOGRAPHIC
assert added_grid.gmt.registration is GridRegistration.PIXEL
assert added_grid.gmt.gtype is GridType.GEOGRAPHIC
47 changes: 40 additions & 7 deletions pygmt/tests/test_xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def test_xarray_backend_load_dataarray():

def test_xarray_backend_gmt_open_nc_grid():
"""
Ensure that passing engine='gmt' to xarray.open_dataarray works for opening NetCDF
grids.
Ensure that passing engine='gmt' to xarray.open_dataarray works to open a netCDF
grid.
"""
with xr.open_dataarray(
"@static_earth_relief.nc", engine="gmt", raster_kind="grid"
Expand All @@ -52,10 +52,29 @@ def test_xarray_backend_gmt_open_nc_grid():
assert da.gmt.registration is GridRegistration.PIXEL


def test_xarray_backend_gmt_open_nc_grid_with_region_bbox():
"""
Ensure that passing engine='gmt' with a `region` argument to xarray.open_dataarray
works to open a netCDF grid over a specific bounding box.
"""
with xr.open_dataarray(
"@static_earth_relief.nc",
engine="gmt",
raster_kind="grid",
region=[-52, -48, -18, -12],
) as da:
assert da.sizes == {"lat": 6, "lon": 4}
npt.assert_allclose(da.lat, [-17.5, -16.5, -15.5, -14.5, -13.5, -12.5])
npt.assert_allclose(da.lon, [-51.5, -50.5, -49.5, -48.5])
assert da.dtype == "float32"
assert da.gmt.gtype is GridType.GEOGRAPHIC
assert da.gmt.registration is GridRegistration.PIXEL


def test_xarray_backend_gmt_open_tif_image():
"""
Ensure that passing engine='gmt' to xarray.open_dataarray works for opening GeoTIFF
images.
Ensure that passing engine='gmt' to xarray.open_dataarray works to open a GeoTIFF
image.
"""
with xr.open_dataarray("@earth_day_01d", engine="gmt", raster_kind="image") as da:
assert da.sizes == {"band": 3, "y": 180, "x": 360}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coordinate names are y/x when region=None, but lat/lon when region is not None at L90 below. Need to fix this inconsistency.

Expand All @@ -64,6 +83,22 @@ def test_xarray_backend_gmt_open_tif_image():
assert da.gmt.registration is GridRegistration.PIXEL


def test_xarray_backend_gmt_open_tif_image_with_region_iso():
"""
Ensure that passing engine='gmt' with a `region` argument to xarray.open_dataarray
works to open a GeoTIFF image over a specific ISO country code border.
"""
with xr.open_dataarray(
"@earth_day_01d", engine="gmt", raster_kind="image", region="BN"
) as da:
assert da.sizes == {"band": 3, "lat": 2, "lon": 2}
npt.assert_allclose(da.lat, [5.5, 4.5])
npt.assert_allclose(da.lon, [114.5, 115.5])
assert da.dtype == "uint8"
assert da.gmt.gtype is GridType.GEOGRAPHIC
assert da.gmt.registration is GridRegistration.PIXEL


def test_xarray_backend_gmt_load_grd_grid():
"""
Ensure that passing engine='gmt' to xarray.load_dataarray works for loading GRD
Expand All @@ -88,9 +123,7 @@ def test_xarray_backend_gmt_read_invalid_kind():
"""
with pytest.raises(
TypeError,
match=re.escape(
"GMTBackendEntrypoint.open_dataset() missing 1 required keyword-only argument: 'raster_kind'"
),
match=re.escape("missing a required argument: 'raster_kind'"),
):
xr.open_dataarray("nokind.nc", engine="gmt")

Expand Down
43 changes: 38 additions & 5 deletions pygmt/xarray/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
An xarray backend for reading raster grid/image files using the 'gmt' engine.
"""

from collections.abc import Sequence
from typing import Literal

import xarray as xr
from pygmt._typing import PathLike
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import build_arg_list
from pygmt.helpers import build_arg_list, kwargs_to_strings
from pygmt.src.which import which
from xarray.backends import BackendEntrypoint

Expand All @@ -30,6 +31,9 @@ class GMTBackendEntrypoint(BackendEntrypoint):
- ``"grid"``: for reading single-band raster grids
- ``"image"``: for reading multi-band raster images

Optionally, you can pass in a `region`in the form of a sequence [*xmin*, *xmax*,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Optionally, you can pass in a `region`in the form of a sequence [*xmin*, *xmax*,
Optionally, you can pass in a ``region`` in the form of a sequence [*xmin*, *xmax*,

*ymin*, *ymax*] or an ISO country code.

Examples
--------
Read a single-band netCDF file using ``raster_kind="grid"``
Expand Down Expand Up @@ -68,18 +72,45 @@ class GMTBackendEntrypoint(BackendEntrypoint):
* band (band) uint8... 1 2 3
Attributes:...
long_name: z

Load a single-band netCDF file using ``raster_kind="grid"`` over a bounding box
region.

>>> da_grid = xr.load_dataarray(
... "@tut_bathy.nc", engine="gmt", raster_kind="grid", region=[-64, -62, 32, 33]
... )
>>> da_grid
<xarray.DataArray 'z' (lat: 13, lon: 25)>...
array([[-4369., -4587., -4469., -4409., -4587., -4505., -4403., -4405.,
-4466., -4595., -4609., -4608., -4606., -4607., -4607., -4597.,
...
-4667., -4642., -4677., -4795., -4797., -4800., -4803., -4818.,
-4820.]], dtype=float32)
Coordinates:
* lat (lat) float64... 32.0 32.08 32.17 32.25 ... 32.83 32.92 33.0
* lon (lon) float64... -64.0 -63.92 -63.83 ... -62.17 -62.08 -62.0
Attributes:...
Conventions: CF-1.7
title: ETOPO5 global topography
history: grdreformat -fg bermuda.grd bermuda.nc=ns
description: /home/elepaio5/data/grids/etopo5.i2
actual_range: [-4968. -4315.]
long_name: Topography
units: m
"""

description = "Open raster (.grd, .nc or .tif) files in Xarray via GMT."
open_dataset_parameters = ("filename_or_obj", "raster_kind")
open_dataset_parameters = ("filename_or_obj", "raster_kind", "region")
url = "https://pygmt.org/dev/api/generated/pygmt.GMTBackendEntrypoint.html"

@kwargs_to_strings(region="sequence")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been think if we should avoid using the @kwargs_to_strings decorator in new functions/methods, and instead write a new function like seqjoin which does exactly the same thing.

def open_dataset( # type: ignore[override]
self,
filename_or_obj: PathLike,
*,
drop_variables=None, # noqa: ARG002
raster_kind: Literal["grid", "image"],
region: Sequence[float] | str | None = None,
# other backend specific keyword arguments
# `chunks` and `cache` DO NOT go here, they are handled by xarray
) -> xr.Dataset:
Expand All @@ -94,14 +125,17 @@ def open_dataset( # type: ignore[override]
:gmt-docs:`reference/features.html#grid-file-format`.
raster_kind
Whether to read the file as a "grid" (single-band) or "image" (multi-band).
region
Optional. The subregion of the grid or image to load, in the form of a
sequence [*xmin*, *xmax*, *ymin*, *ymax*] or an ISO country code.
Comment on lines +129 to +130
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Optional. The subregion of the grid or image to load, in the form of a
sequence [*xmin*, *xmax*, *ymin*, *ymax*] or an ISO country code.
The subregion of the grid or image to load, in the form of a sequence
[*xmin*, *xmax*, *ymin*, *ymax*] or an ISO country code.

"""
if raster_kind not in {"grid", "image"}:
msg = f"Invalid raster kind: '{raster_kind}'. Valid values are 'grid' or 'image'."
raise GMTInvalidInput(msg)

with Session() as lib:
with lib.virtualfile_out(kind=raster_kind) as voutfile:
kwdict = {"T": {"grid": "g", "image": "i"}[raster_kind]}
kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[raster_kind]}
lib.call_module(
module="read",
args=[filename_or_obj, voutfile, *build_arg_list(kwdict)],
Expand All @@ -111,9 +145,8 @@ def open_dataset( # type: ignore[override]
vfname=voutfile, kind=raster_kind
)
# Add "source" encoding
source = which(fname=filename_or_obj)
source: str | list = which(fname=filename_or_obj, verbose="q")
raster.encoding["source"] = (
source[0] if isinstance(source, list) else source
)
_ = raster.gmt # Load GMTDataArray accessor information
return raster.to_dataset()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it's likely that the accessor information will be lost when converting via to_dataset.