Skip to content

Commit 8b2a74c

Browse files
authored
Refactor the _load_remote_dataset function to load tiled and non-tiled grids in a consistent way (#3120)
1 parent 44f44d3 commit 8b2a74c

File tree

2 files changed

+42
-29
lines changed

2 files changed

+42
-29
lines changed

pygmt/datasets/load_remote_dataset.py

+40 -26
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,12 @@
44

55
from __future__ import annotations
66

7-
from typing import TYPE_CHECKING, ClassVar, NamedTuple
7+
from typing import TYPE_CHECKING, ClassVar, Literal, NamedTuple
88

9+
from pygmt.clib import Session
910
from pygmt.exceptions import GMTInvalidInput
10-
from pygmt.helpers import kwargs_to_strings
11-
from pygmt.io import load_dataarray
12-
from pygmt.src import grdcut, which
11+
from pygmt.helpers import build_arg_list, kwargs_to_strings
12+
from pygmt.src import which
1313

1414
if TYPE_CHECKING:
1515
import xarray as xr
@@ -344,7 +344,7 @@ def _load_remote_dataset(
344344
dataset_prefix: str,
345345
resolution: str,
346346
region: str | list,
347-
registration: str,
347+
registration: Literal["gridline", "pixel", None],
348348
) -> xr.DataArray:
349349
r"""
350350
Load GMT remote datasets.
@@ -370,54 +370,68 @@ def _load_remote_dataset(
370370
371371
Returns
372372
-------
373-
grid : :class:`xarray.DataArray`
373+
grid
374374
The GMT remote dataset grid.
375375
376376
Note
377377
----
378-
The returned :class:`xarray.DataArray` doesn't support slice operation for tiled
379-
grids.
378+
The registration and coordinate system type of the returned
379+
:class:`xarray.DataArray` grid can be accessed via the GMT accessors (i.e.,
380+
``grid.gmt.registration`` and ``grid.gmt.gtype`` respectively). However, these
381+
properties may be lost after specific grid operations (such as slicing) and will
382+
need to be manually set before passing the grid to any PyGMT data processing or
383+
plotting functions. Refer to :class:`pygmt.GMTDataArrayAccessor` for detailed
384+
explanations and workarounds.
380385
"""
381386
dataset = datasets[dataset_name]
382387

388+
# Check resolution
383389
if resolution not in dataset.resolutions:
384390
raise GMTInvalidInput(
385391
f"Invalid resolution '{resolution}' for {dataset.title} dataset. "
386392
f"Available resolutions are: {', '.join(dataset.resolutions)}."
387393
)
394+
resinfo = dataset.resolutions[resolution]
388395

389-
# check registration
390-
valid_registrations = dataset.resolutions[resolution].registrations
396+
# Check registration
391397
if registration is None:
392-
# use gridline registration unless only pixel registration is available
393-
registration = "gridline" if "gridline" in valid_registrations else "pixel"
398+
# Use gridline registration unless only pixel registration is available
399+
registration = "gridline" if "gridline" in resinfo.registrations else "pixel"
394400
elif registration in ("pixel", "gridline"):
395-
if registration not in valid_registrations:
401+
if registration not in resinfo.registrations:
396402
raise GMTInvalidInput(
397403
f"{registration} registration is not available for the "
398404
f"{resolution} {dataset.title} dataset. Only "
399-
f"{valid_registrations[0]} registration is available."
405+
f"{resinfo.registrations[0]} registration is available."
400406
)
401407
else:
402408
raise GMTInvalidInput(
403409
f"Invalid grid registration: '{registration}', should be either 'pixel', "
404410
"'gridline' or None. Default is None, where a gridline-registered grid is "
405411
"returned unless only the pixel-registered grid is available."
406412
)
407-
reg = f"_{registration[0]}"
408413

409-
# different ways to load tiled and non-tiled grids.
410-
# Known issue: tiled grids don't support slice operation
411-
# See https://github.com/GenericMappingTools/pygmt/issues/524
412-
if region is None:
413-
if dataset.resolutions[resolution].tiled:
414-
raise GMTInvalidInput(
415-
f"'region' is required for {dataset.title} resolution '{resolution}'."
414+
fname = f"@{dataset_prefix}{resolution}_{registration[0]}"
415+
if resinfo.tiled and region is None:
416+
raise GMTInvalidInput(
417+
f"'region' is required for {dataset.title} resolution '{resolution}'."
418+
)
419+
420+
# Currently, only grids are supported. Will support images in the future.
421+
kwdict = {"T": "g", "R": region} # region can be None
422+
with Session() as lib:
423+
with lib.virtualfile_out(kind="grid") as voutgrd:
424+
lib.call_module(
425+
module="read",
426+
args=[fname, voutgrd, *build_arg_list(kwdict)],
416427
)
417-
fname = which(f"@{dataset_prefix}{resolution}{reg}", download="a")
418-
grid = load_dataarray(fname, engine="netcdf4")
419-
else:
420-
grid = grdcut(f"@{dataset_prefix}{resolution}{reg}", region=region)
428+
grid = lib.virtualfile_to_raster(outgrid=None, vfname=voutgrd)
429+
430+
# Full path to the grid if not tiled grids.
431+
source = which(fname, download="a") if not resinfo.tiled else None
432+
# Manually add source to xarray.DataArray encoding to make the GMT accessors work.
433+
if source:
434+
grid.encoding["source"] = source
421435

422436
# Add some metadata to the grid
423437
grid.name = dataset.name

pygmt/tests/test_accessor.py

+2 -3
Original file line number | Diff line number | Diff line change
@@ -115,9 +115,8 @@ def test_accessor_grid_source_file_not_exist():
115115
# Registration and gtype are correct
116116
assert grid.gmt.registration == 1
117117
assert grid.gmt.gtype == 1
118-
# The source grid file is defined but doesn't exist
119-
assert grid.encoding["source"].endswith(".nc")
120-
assert not Path(grid.encoding["source"]).exists()
118+
# The source grid file is undefined.
119+
assert grid.encoding.get("source") is None
121120

122121
# For a sliced grid, fallback to default registration and gtype,
123122
# because the source grid file doesn't exist.

0 commit comments

Comments
 (0)