diff --git a/pygmt/datasets/load_remote_dataset.py b/pygmt/datasets/load_remote_dataset.py index c6f0471b8ee..b86174dd54d 100644 --- a/pygmt/datasets/load_remote_dataset.py +++ b/pygmt/datasets/load_remote_dataset.py @@ -7,10 +7,8 @@ from typing import Any, Literal, NamedTuple import xarray as xr -from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import build_arg_list, kwargs_to_strings -from pygmt.src import which +from pygmt.helpers import kwargs_to_strings with contextlib.suppress(ImportError): # rioxarray is needed to register the rio accessor @@ -581,22 +579,9 @@ def _load_remote_dataset( raise GMTInvalidInput(msg) fname = f"@{prefix}_{resolution}_{reg}" - kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[dataset.kind]} - with Session() as lib: - with lib.virtualfile_out(kind=dataset.kind) as voutgrd: - lib.call_module( - module="read", - args=[fname, voutgrd, *build_arg_list(kwdict)], - ) - grid = lib.virtualfile_to_raster( - kind=dataset.kind, outgrid=None, vfname=voutgrd - ) - - # Full path to the grid if not tiled grids. - source = which(fname, download="a") if not resinfo.tiled else None - # Manually add source to xarray.DataArray encoding to make the GMT accessors work. - if source: - grid.encoding["source"] = source + grid = xr.load_dataarray( + fname, engine="gmt", raster_kind=dataset.kind, region=region + ) # Add some metadata to the grid grid.attrs["description"] = dataset.description diff --git a/pygmt/tests/test_xarray_accessor.py b/pygmt/tests/test_xarray_accessor.py index 11b184110f6..bd42162d6b2 100644 --- a/pygmt/tests/test_xarray_accessor.py +++ b/pygmt/tests/test_xarray_accessor.py @@ -132,10 +132,13 @@ def test_xarray_accessor_sliced_datacube(): Path(fname).unlink() -def test_xarray_accessor_grid_source_file_not_exist(): +def test_xarray_accessor_tiled_grid_slice_and_add(): """ - Check that the accessor fallbacks to the default registration and gtype when the - grid source file (i.e., grid.encoding["source"]) doesn't exist. + Check that the accessor works to get the registration and gtype when the grid source + file is from a tiled grid, that slicing doesn't affect registration/gtype, but math + operations do return the default registration/gtype as a fallback. + + Unit test to track https://github.com/GenericMappingTools/pygmt/issues/524 """ # Load the 05m earth relief grid, which is stored as tiles. grid = load_earth_relief( @@ -144,17 +147,25 @@ def test_xarray_accessor_grid_source_file_not_exist(): # Registration and gtype are correct. assert grid.gmt.registration is GridRegistration.PIXEL assert grid.gmt.gtype is GridType.GEOGRAPHIC - # The source grid file is undefined. - assert grid.encoding.get("source") is None + # The source grid file for tiled grids is the first tile + assert grid.encoding["source"].endswith("S90W180.earth_relief_05m_p.nc") - # For a sliced grid, fallback to default registration and gtype, because the source - # grid file doesn't exist. + # For a sliced grid, ensure we don't fallback to the default registration (gridline) + # and gtype (cartesian), because the source grid file should still exist. sliced_grid = grid[1:3, 1:3] - assert sliced_grid.gmt.registration is GridRegistration.GRIDLINE - assert sliced_grid.gmt.gtype is GridType.CARTESIAN - - # Still possible to manually set registration and gtype. - sliced_grid.gmt.registration = GridRegistration.PIXEL - sliced_grid.gmt.gtype = GridType.GEOGRAPHIC + assert sliced_grid.encoding["source"].endswith("S90W180.earth_relief_05m_p.nc") assert sliced_grid.gmt.registration is GridRegistration.PIXEL assert sliced_grid.gmt.gtype is GridType.GEOGRAPHIC + + # For a grid that underwent mathematical operations, fallback to default + # registration and gtype, because the source grid file doesn't exist. + added_grid = sliced_grid + 9 + assert added_grid.encoding == {} + assert added_grid.gmt.registration is GridRegistration.GRIDLINE + assert added_grid.gmt.gtype is GridType.CARTESIAN + + # Still possible to manually set registration and gtype. + added_grid.gmt.registration = GridRegistration.PIXEL + added_grid.gmt.gtype = GridType.GEOGRAPHIC + assert added_grid.gmt.registration is GridRegistration.PIXEL + assert added_grid.gmt.gtype is GridType.GEOGRAPHIC diff --git a/pygmt/tests/test_xarray_backend.py b/pygmt/tests/test_xarray_backend.py index 4704bea9bf0..5a373926edf 100644 --- a/pygmt/tests/test_xarray_backend.py +++ b/pygmt/tests/test_xarray_backend.py @@ -40,8 +40,8 @@ def test_xarray_backend_load_dataarray(): def test_xarray_backend_gmt_open_nc_grid(): """ - Ensure that passing engine='gmt' to xarray.open_dataarray works for opening NetCDF - grids. + Ensure that passing engine='gmt' to xarray.open_dataarray works to open a netCDF + grid. """ with xr.open_dataarray( "@static_earth_relief.nc", engine="gmt", raster_kind="grid" @@ -52,10 +52,29 @@ def test_xarray_backend_gmt_open_nc_grid(): assert da.gmt.registration is GridRegistration.PIXEL +def test_xarray_backend_gmt_open_nc_grid_with_region_bbox(): + """ + Ensure that passing engine='gmt' with a `region` argument to xarray.open_dataarray + works to open a netCDF grid over a specific bounding box. + """ + with xr.open_dataarray( + "@static_earth_relief.nc", + engine="gmt", + raster_kind="grid", + region=[-52, -48, -18, -12], + ) as da: + assert da.sizes == {"lat": 6, "lon": 4} + npt.assert_allclose(da.lat, [-17.5, -16.5, -15.5, -14.5, -13.5, -12.5]) + npt.assert_allclose(da.lon, [-51.5, -50.5, -49.5, -48.5]) + assert da.dtype == "float32" + assert da.gmt.gtype is GridType.GEOGRAPHIC + assert da.gmt.registration is GridRegistration.PIXEL + + def test_xarray_backend_gmt_open_tif_image(): """ - Ensure that passing engine='gmt' to xarray.open_dataarray works for opening GeoTIFF - images. + Ensure that passing engine='gmt' to xarray.open_dataarray works to open a GeoTIFF + image. """ with xr.open_dataarray("@earth_day_01d", engine="gmt", raster_kind="image") as da: assert da.sizes == {"band": 3, "y": 180, "x": 360} @@ -64,6 +83,22 @@ def test_xarray_backend_gmt_open_tif_image(): assert da.gmt.registration is GridRegistration.PIXEL +def test_xarray_backend_gmt_open_tif_image_with_region_iso(): + """ + Ensure that passing engine='gmt' with a `region` argument to xarray.open_dataarray + works to open a GeoTIFF image over a specific ISO country code border. + """ + with xr.open_dataarray( + "@earth_day_01d", engine="gmt", raster_kind="image", region="BN" + ) as da: + assert da.sizes == {"band": 3, "lat": 2, "lon": 2} + npt.assert_allclose(da.lat, [5.5, 4.5]) + npt.assert_allclose(da.lon, [114.5, 115.5]) + assert da.dtype == "uint8" + assert da.gmt.gtype is GridType.GEOGRAPHIC + assert da.gmt.registration is GridRegistration.PIXEL + + def test_xarray_backend_gmt_load_grd_grid(): """ Ensure that passing engine='gmt' to xarray.load_dataarray works for loading GRD @@ -88,9 +123,7 @@ def test_xarray_backend_gmt_read_invalid_kind(): """ with pytest.raises( TypeError, - match=re.escape( - "GMTBackendEntrypoint.open_dataset() missing 1 required keyword-only argument: 'raster_kind'" - ), + match=re.escape("missing a required argument: 'raster_kind'"), ): xr.open_dataarray("nokind.nc", engine="gmt") diff --git a/pygmt/xarray/backend.py b/pygmt/xarray/backend.py index a95e98983db..25adf5168f4 100644 --- a/pygmt/xarray/backend.py +++ b/pygmt/xarray/backend.py @@ -2,13 +2,14 @@ An xarray backend for reading raster grid/image files using the 'gmt' engine. """ +from collections.abc import Sequence from typing import Literal import xarray as xr from pygmt._typing import PathLike from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import build_arg_list +from pygmt.helpers import build_arg_list, kwargs_to_strings from pygmt.src.which import which from xarray.backends import BackendEntrypoint @@ -30,6 +31,9 @@ class GMTBackendEntrypoint(BackendEntrypoint): - ``"grid"``: for reading single-band raster grids - ``"image"``: for reading multi-band raster images + Optionally, you can pass in a `region`in the form of a sequence [*xmin*, *xmax*, + *ymin*, *ymax*] or an ISO country code. + Examples -------- Read a single-band netCDF file using ``raster_kind="grid"`` @@ -68,18 +72,45 @@ class GMTBackendEntrypoint(BackendEntrypoint): * band (band) uint8... 1 2 3 Attributes:... long_name: z + + Load a single-band netCDF file using ``raster_kind="grid"`` over a bounding box + region. + + >>> da_grid = xr.load_dataarray( + ... "@tut_bathy.nc", engine="gmt", raster_kind="grid", region=[-64, -62, 32, 33] + ... ) + >>> da_grid + ... + array([[-4369., -4587., -4469., -4409., -4587., -4505., -4403., -4405., + -4466., -4595., -4609., -4608., -4606., -4607., -4607., -4597., + ... + -4667., -4642., -4677., -4795., -4797., -4800., -4803., -4818., + -4820.]], dtype=float32) + Coordinates: + * lat (lat) float64... 32.0 32.08 32.17 32.25 ... 32.83 32.92 33.0 + * lon (lon) float64... -64.0 -63.92 -63.83 ... -62.17 -62.08 -62.0 + Attributes:... + Conventions: CF-1.7 + title: ETOPO5 global topography + history: grdreformat -fg bermuda.grd bermuda.nc=ns + description: /home/elepaio5/data/grids/etopo5.i2 + actual_range: [-4968. -4315.] + long_name: Topography + units: m """ description = "Open raster (.grd, .nc or .tif) files in Xarray via GMT." - open_dataset_parameters = ("filename_or_obj", "raster_kind") + open_dataset_parameters = ("filename_or_obj", "raster_kind", "region") url = "https://pygmt.org/dev/api/generated/pygmt.GMTBackendEntrypoint.html" + @kwargs_to_strings(region="sequence") def open_dataset( # type: ignore[override] self, filename_or_obj: PathLike, *, drop_variables=None, # noqa: ARG002 raster_kind: Literal["grid", "image"], + region: Sequence[float] | str | None = None, # other backend specific keyword arguments # `chunks` and `cache` DO NOT go here, they are handled by xarray ) -> xr.Dataset: @@ -94,6 +125,9 @@ def open_dataset( # type: ignore[override] :gmt-docs:`reference/features.html#grid-file-format`. raster_kind Whether to read the file as a "grid" (single-band) or "image" (multi-band). + region + Optional. The subregion of the grid or image to load, in the form of a + sequence [*xmin*, *xmax*, *ymin*, *ymax*] or an ISO country code. """ if raster_kind not in {"grid", "image"}: msg = f"Invalid raster kind: '{raster_kind}'. Valid values are 'grid' or 'image'." @@ -101,7 +135,7 @@ def open_dataset( # type: ignore[override] with Session() as lib: with lib.virtualfile_out(kind=raster_kind) as voutfile: - kwdict = {"T": {"grid": "g", "image": "i"}[raster_kind]} + kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[raster_kind]} lib.call_module( module="read", args=[filename_or_obj, voutfile, *build_arg_list(kwdict)], @@ -111,9 +145,8 @@ def open_dataset( # type: ignore[override] vfname=voutfile, kind=raster_kind ) # Add "source" encoding - source = which(fname=filename_or_obj) + source: str | list = which(fname=filename_or_obj, verbose="q") raster.encoding["source"] = ( source[0] if isinstance(source, list) else source ) - _ = raster.gmt # Load GMTDataArray accessor information return raster.to_dataset()