diff --git a/pygmt/src/grdcut.py b/pygmt/src/grdcut.py index d248d69ae27..21876c60d6a 100644 --- a/pygmt/src/grdcut.py +++ b/pygmt/src/grdcut.py @@ -2,23 +2,24 @@ grdcut - Extract subregion from a grid. """ +from typing import Literal + import xarray as xr from pygmt.clib import Session +from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( - GMTTempFile, build_arg_list, + data_kind, fmt_docstring, kwargs_to_strings, use_alias, ) -from pygmt.io import load_dataarray __doctest_skip__ = ["grdcut"] @fmt_docstring @use_alias( - G="outgrid", R="region", J="projection", N="extend", @@ -28,9 +29,11 @@ f="coltypes", ) @kwargs_to_strings(R="sequence") -def grdcut(grid, **kwargs) -> xr.DataArray | None: +def grdcut( + grid, kind: Literal["grid", "image"] = "grid", outgrid: str | None = None, **kwargs +) -> xr.DataArray | None: r""" - Extract subregion from a grid. + Extract subregion from a grid or image. Produce a new ``outgrid`` file which is a subregion of ``grid``. The subregion is specified with ``region``; the specified range must not exceed @@ -48,6 +51,11 @@ def grdcut(grid, **kwargs) -> xr.DataArray | None: Parameters ---------- {grid} + kind + The raster data kind. Valid values are ``"grid"`` and ``"image"``. When the + input ``grid`` is a file name, it's difficult to determine if the file is a grid + or an image, so we need to specify the raster kind explicitly. The default is + ``"grid"``. {outgrid} {projection} {region} @@ -100,13 +108,27 @@ def grdcut(grid, **kwargs) -> xr.DataArray | None: >>> # 12° E to 15° E and a latitude range of 21° N to 24° N >>> new_grid = pygmt.grdcut(grid=grid, region=[12, 15, 21, 24]) """ - with GMTTempFile(suffix=".nc") as tmpfile: - with Session() as lib: - with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd: - if (outgrid := kwargs.get("G")) is None: - kwargs["G"] = outgrid = tmpfile.name # output to tmpfile - lib.call_module( - module="grdcut", args=build_arg_list(kwargs, infile=vingrd) - ) + if kind not in {"grid", "image"}: + msg = f"Invalid raster kind: '{kind}'. Valid values are 'grid' and 'image'." + raise GMTInvalidInput(msg) + + # Determine the output data kind based on the input data kind. + match inkind := data_kind(grid): + case "grid" | "image": + outkind = inkind + case "file": + outkind = kind + case _: + msg = f"Unsupported data type {type(grid)}." + raise GMTInvalidInput(msg) - return load_dataarray(outgrid) if outgrid == tmpfile.name else None + with Session() as lib: + with ( + lib.virtualfile_in(check_kind="raster", data=grid) as vingrd, + lib.virtualfile_out(kind=outkind, fname=outgrid) as voutgrd, + ): + kwargs["G"] = voutgrd + lib.call_module(module="grdcut", args=build_arg_list(kwargs, infile=vingrd)) + return lib.virtualfile_to_raster( + vfname=voutgrd, kind=outkind, outgrid=outgrid + ) diff --git a/pygmt/tests/test_grdcut.py b/pygmt/tests/test_grdcut.py index dbf5dd21f49..5dccb06fd52 100644 --- a/pygmt/tests/test_grdcut.py +++ b/pygmt/tests/test_grdcut.py @@ -70,3 +70,11 @@ def test_grdcut_fails(): """ with pytest.raises(GMTInvalidInput): grdcut(np.arange(10).reshape((5, 2))) + + +def test_grdcut_invalid_kind(grid, region): + """ + Check that grdcut fails with incorrect 'kind'. + """ + with pytest.raises(GMTInvalidInput): + grdcut(grid, kind="invalid", region=region) diff --git a/pygmt/tests/test_grdcut_image.py b/pygmt/tests/test_grdcut_image.py new file mode 100644 index 00000000000..585a7b1b4a5 --- /dev/null +++ b/pygmt/tests/test_grdcut_image.py @@ -0,0 +1,88 @@ +""" +Test pygmt.grdcut on images. +""" + +from pathlib import Path + +import numpy as np +import pytest +import xarray as xr +from pygmt import grdcut +from pygmt.datasets import load_blue_marble +from pygmt.helpers import GMTTempFile + +try: + import rioxarray + + _HAS_RIOXARRAY = True +except ImportError: + _HAS_RIOXARRAY = False + + +@pytest.fixture(scope="module", name="region") +def fixture_region(): + """ + Set the data region. + """ + return [-53, -49, -20, -17] + + +@pytest.fixture(scope="module", name="expected_image") +def fixture_expected_image(): + """ + Load the expected grdcut image result. + """ + return xr.DataArray( + data=np.array( + [ + [[90, 93, 95, 90], [91, 90, 91, 91], [91, 90, 89, 90]], + [[87, 88, 88, 89], [88, 87, 86, 85], [90, 90, 89, 88]], + [[48, 49, 49, 45], [49, 48, 47, 45], [48, 47, 48, 46]], + ], + dtype=np.uint8, + ), + coords={ + "band": [1, 2, 3], + "x": [-52.5, -51.5, -50.5, -49.5], + "y": [-17.5, -18.5, -19.5], + }, + dims=["band", "y", "x"], + attrs={ + "scale_factor": 1.0, + "add_offset": 0.0, + }, + ) + + +@pytest.mark.benchmark +def test_grdcut_image_file(region, expected_image): + """ + Test grdcut on an input image file. + """ + result = grdcut("@earth_day_01d", region=region, kind="image") + xr.testing.assert_allclose(a=result, b=expected_image) + + +@pytest.mark.benchmark +@pytest.mark.skipif(not _HAS_RIOXARRAY, reason="rioxarray is not installed") +def test_grdcut_image_dataarray(region, expected_image): + """ + Test grdcut on an input xarray.DataArray object. + """ + raster = load_blue_marble() + result = grdcut(raster, region=region, kind="image") + xr.testing.assert_allclose(a=result, b=expected_image) + + +def test_grdcut_image_file_in_file_out(region, expected_image): + """ + Test grdcut on an input image file and outputs to another image file. + """ + with GMTTempFile(suffix=".tif") as tmp: + result = grdcut("@earth_day_01d", region=region, outgrid=tmp.name) + assert result is None + assert Path(tmp.name).stat().st_size > 0 + if _HAS_RIOXARRAY: + with rioxarray.open_rasterio(tmp.name) as raster: + image = raster.load().drop_vars("spatial_ref") + xr.testing.assert_allclose(a=image, b=expected_image)