diff --git a/pygmt/datasets/load_remote_dataset.py b/pygmt/datasets/load_remote_dataset.py index e0a97a9ea73..ac23714b34e 100644 --- a/pygmt/datasets/load_remote_dataset.py +++ b/pygmt/datasets/load_remote_dataset.py @@ -7,10 +7,7 @@ from typing import Any, Literal, NamedTuple import xarray as xr -from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import build_arg_list, kwargs_to_strings -from pygmt.src import which with contextlib.suppress(ImportError): # rioxarray is needed to register the rio accessor @@ -502,7 +499,6 @@ class GMTRemoteDataset(NamedTuple): } -@kwargs_to_strings(region="sequence") def _load_remote_dataset( name: str, prefix: str, @@ -581,23 +577,9 @@ def _load_remote_dataset( raise GMTInvalidInput(msg) fname = f"@{prefix}_{resolution}_{reg}" - kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[dataset.kind]} - with Session() as lib: - with lib.virtualfile_out(kind=dataset.kind) as voutgrd: - lib.call_module( - module="read", - args=[fname, voutgrd, *build_arg_list(kwdict)], - ) - grid = lib.virtualfile_to_raster( - kind=dataset.kind, outgrid=None, vfname=voutgrd - ) - - # Full path to the grid - source: str | list = which(fname, verbose="q") - if resinfo.tiled: - source = sorted(source)[0] # get first grid for tiled grids - # Manually add source to xarray.DataArray encoding to make the GMT accessors work. - grid.encoding["source"] = source + grid = xr.load_dataarray( + fname, engine="gmt", raster_kind=dataset.kind, region=region + ) # Add some metadata to the grid grid.attrs["description"] = dataset.description diff --git a/pygmt/tests/test_xarray_backend.py b/pygmt/tests/test_xarray_backend.py index 4704bea9bf0..5a373926edf 100644 --- a/pygmt/tests/test_xarray_backend.py +++ b/pygmt/tests/test_xarray_backend.py @@ -40,8 +40,8 @@ def test_xarray_backend_load_dataarray(): def test_xarray_backend_gmt_open_nc_grid(): """ - Ensure that passing engine='gmt' to xarray.open_dataarray works for opening NetCDF - grids. + Ensure that passing engine='gmt' to xarray.open_dataarray works to open a netCDF + grid. """ with xr.open_dataarray( "@static_earth_relief.nc", engine="gmt", raster_kind="grid" @@ -52,10 +52,29 @@ def test_xarray_backend_gmt_open_nc_grid(): assert da.gmt.registration is GridRegistration.PIXEL +def test_xarray_backend_gmt_open_nc_grid_with_region_bbox(): + """ + Ensure that passing engine='gmt' with a `region` argument to xarray.open_dataarray + works to open a netCDF grid over a specific bounding box. + """ + with xr.open_dataarray( + "@static_earth_relief.nc", + engine="gmt", + raster_kind="grid", + region=[-52, -48, -18, -12], + ) as da: + assert da.sizes == {"lat": 6, "lon": 4} + npt.assert_allclose(da.lat, [-17.5, -16.5, -15.5, -14.5, -13.5, -12.5]) + npt.assert_allclose(da.lon, [-51.5, -50.5, -49.5, -48.5]) + assert da.dtype == "float32" + assert da.gmt.gtype is GridType.GEOGRAPHIC + assert da.gmt.registration is GridRegistration.PIXEL + + def test_xarray_backend_gmt_open_tif_image(): """ - Ensure that passing engine='gmt' to xarray.open_dataarray works for opening GeoTIFF - images. + Ensure that passing engine='gmt' to xarray.open_dataarray works to open a GeoTIFF + image. """ with xr.open_dataarray("@earth_day_01d", engine="gmt", raster_kind="image") as da: assert da.sizes == {"band": 3, "y": 180, "x": 360} @@ -64,6 +83,22 @@ def test_xarray_backend_gmt_open_tif_image(): assert da.gmt.registration is GridRegistration.PIXEL +def test_xarray_backend_gmt_open_tif_image_with_region_iso(): + """ + Ensure that passing engine='gmt' with a `region` argument to xarray.open_dataarray + works to open a GeoTIFF image over a specific ISO country code border. + """ + with xr.open_dataarray( + "@earth_day_01d", engine="gmt", raster_kind="image", region="BN" + ) as da: + assert da.sizes == {"band": 3, "lat": 2, "lon": 2} + npt.assert_allclose(da.lat, [5.5, 4.5]) + npt.assert_allclose(da.lon, [114.5, 115.5]) + assert da.dtype == "uint8" + assert da.gmt.gtype is GridType.GEOGRAPHIC + assert da.gmt.registration is GridRegistration.PIXEL + + def test_xarray_backend_gmt_load_grd_grid(): """ Ensure that passing engine='gmt' to xarray.load_dataarray works for loading GRD @@ -88,9 +123,7 @@ def test_xarray_backend_gmt_read_invalid_kind(): """ with pytest.raises( TypeError, - match=re.escape( - "GMTBackendEntrypoint.open_dataset() missing 1 required keyword-only argument: 'raster_kind'" - ), + match=re.escape("missing a required argument: 'raster_kind'"), ): xr.open_dataarray("nokind.nc", engine="gmt") diff --git a/pygmt/xarray/backend.py b/pygmt/xarray/backend.py index a95e98983db..24530881fb4 100644 --- a/pygmt/xarray/backend.py +++ b/pygmt/xarray/backend.py @@ -2,13 +2,14 @@ An xarray backend for reading raster grid/image files using the 'gmt' engine. """ +from collections.abc import Sequence from typing import Literal import xarray as xr from pygmt._typing import PathLike from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import build_arg_list +from pygmt.helpers import build_arg_list, kwargs_to_strings from pygmt.src.which import which from xarray.backends import BackendEntrypoint @@ -30,6 +31,9 @@ class GMTBackendEntrypoint(BackendEntrypoint): - ``"grid"``: for reading single-band raster grids - ``"image"``: for reading multi-band raster images + Optionally, you can pass in a ``region`` in the form of a sequence [*xmin*, *xmax*, + *ymin*, *ymax*] or an ISO country code. + Examples -------- Read a single-band netCDF file using ``raster_kind="grid"`` @@ -68,18 +72,45 @@ class GMTBackendEntrypoint(BackendEntrypoint): * band (band) uint8... 1 2 3 Attributes:... long_name: z + + Load a single-band netCDF file using ``raster_kind="grid"`` over a bounding box + region. + + >>> da_grid = xr.load_dataarray( + ... "@tut_bathy.nc", engine="gmt", raster_kind="grid", region=[-64, -62, 32, 33] + ... ) + >>> da_grid + ... + array([[-4369., -4587., -4469., -4409., -4587., -4505., -4403., -4405., + -4466., -4595., -4609., -4608., -4606., -4607., -4607., -4597., + ... + -4667., -4642., -4677., -4795., -4797., -4800., -4803., -4818., + -4820.]], dtype=float32) + Coordinates: + * lat (lat) float64... 32.0 32.08 32.17 32.25 ... 32.83 32.92 33.0 + * lon (lon) float64... -64.0 -63.92 -63.83 ... -62.17 -62.08 -62.0 + Attributes:... + Conventions: CF-1.7 + title: ETOPO5 global topography + history: grdreformat -fg bermuda.grd bermuda.nc=ns + description: /home/elepaio5/data/grids/etopo5.i2 + actual_range: [-4968. -4315.] + long_name: Topography + units: m """ description = "Open raster (.grd, .nc or .tif) files in Xarray via GMT." - open_dataset_parameters = ("filename_or_obj", "raster_kind") + open_dataset_parameters = ("filename_or_obj", "raster_kind", "region") url = "https://pygmt.org/dev/api/generated/pygmt.GMTBackendEntrypoint.html" + @kwargs_to_strings(region="sequence") def open_dataset( # type: ignore[override] self, filename_or_obj: PathLike, *, drop_variables=None, # noqa: ARG002 raster_kind: Literal["grid", "image"], + region: Sequence[float] | str | None = None, # other backend specific keyword arguments # `chunks` and `cache` DO NOT go here, they are handled by xarray ) -> xr.Dataset: @@ -94,6 +125,9 @@ def open_dataset( # type: ignore[override] :gmt-docs:`reference/features.html#grid-file-format`. raster_kind Whether to read the file as a "grid" (single-band) or "image" (multi-band). + region + The subregion of the grid or image to load, in the form of a sequence + [*xmin*, *xmax*, *ymin*, *ymax*] or an ISO country code. """ if raster_kind not in {"grid", "image"}: msg = f"Invalid raster kind: '{raster_kind}'. Valid values are 'grid' or 'image'." @@ -101,7 +135,7 @@ def open_dataset( # type: ignore[override] with Session() as lib: with lib.virtualfile_out(kind=raster_kind) as voutfile: - kwdict = {"T": {"grid": "g", "image": "i"}[raster_kind]} + kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[raster_kind]} lib.call_module( module="read", args=[filename_or_obj, voutfile, *build_arg_list(kwdict)], @@ -111,9 +145,8 @@ def open_dataset( # type: ignore[override] vfname=voutfile, kind=raster_kind ) # Add "source" encoding - source = which(fname=filename_or_obj) + source: str | list = which(fname=filename_or_obj, verbose="q") raster.encoding["source"] = ( - source[0] if isinstance(source, list) else source + sorted(source)[0] if isinstance(source, list) else source ) - _ = raster.gmt # Load GMTDataArray accessor information return raster.to_dataset()