Skip to content

Commit

Permalink
Feature/support multiple lams to the Cutout class (ecmwf#113)
Browse files Browse the repository at this point in the history
* Enhance Cutout class to support multiple LAMs with hierarchical masking.

---------

Co-authored-by: Paulina Met. <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 8, 2024
1 parent 32a2fa9 commit e1ab0b8
Show file tree
Hide file tree
Showing 7 changed files with 1,484 additions and 62 deletions.
280 changes: 218 additions & 62 deletions src/anemoi/datasets/data/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from functools import cached_property

import numpy as np
from scipy.spatial import cKDTree

from .debug import Node
from .debug import debug_indexing
Expand Down Expand Up @@ -142,95 +143,250 @@ def tree(self):


class Cutout(GridsBase):
def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, neighbours=5, plot=False):
from anemoi.datasets.grids import cutout_mask

def __init__(self, datasets, axis=3, cropping_distance=2.0, neighbours=5, min_distance_km=None, plot=None):
"""Initializes a Cutout object for hierarchical management of Limited Area
Models (LAMs) and a global dataset, handling overlapping regions.
Args:
datasets (list): List of LAM and global datasets.
axis (int): Concatenation axis, must be set to 3.
cropping_distance (float): Distance threshold in degrees for
cropping cutouts.
neighbours (int): Number of neighboring points to consider when
constructing masks.
min_distance_km (float, optional): Minimum distance threshold in km
between grid points.
plot (bool, optional): Flag to enable or disable visualization
plots.
"""
super().__init__(datasets, axis)
assert len(datasets) == 2, "CutoutGrids requires two datasets"
assert len(datasets) >= 2, "CutoutGrids requires at least two datasets"
assert axis == 3, "CutoutGrids requires axis=3"
assert cropping_distance >= 0, "cropping_distance must be a non-negative number"
if min_distance_km is not None:
assert min_distance_km >= 0, "min_distance_km must be a non-negative number"

self.lams = datasets[:-1] # Assume the last dataset is the global one
self.globe = datasets[-1]
self.axis = axis
self.cropping_distance = cropping_distance
self.neighbours = neighbours
self.min_distance_km = min_distance_km
self.plot = plot
self.masks = [] # To store the masks for each LAM dataset
self.global_mask = np.ones(self.globe.shape[-1], dtype=bool)

# Initialize cumulative masks
self._initialize_masks()

def _initialize_masks(self):
"""Generates hierarchical masks for each LAM dataset by excluding
overlapping regions with previous LAMs and creating a global mask for
the global dataset.
Raises:
ValueError: If the global mask dimension does not match the global
dataset grid points.
"""
from anemoi.datasets.grids import cutout_mask

# We assume that the LAM is the first dataset, and the global is the second
# Note: the second fields does not really need to be global

self.lam, self.globe = datasets
self.mask = cutout_mask(
self.lam.latitudes,
self.lam.longitudes,
self.globe.latitudes,
self.globe.longitudes,
plot=plot,
min_distance_km=min_distance_km,
cropping_distance=cropping_distance,
neighbours=neighbours,
)
assert len(self.mask) == self.globe.shape[3], (
len(self.mask),
self.globe.shape[3],
)
for i, lam in enumerate(self.lams):
assert len(lam.shape) == len(
self.globe.shape
), "LAMs and global dataset must have the same number of dimensions"
lam_lats = lam.latitudes
lam_lons = lam.longitudes
# Create a mask for the global dataset excluding all LAM points
global_overlap_mask = cutout_mask(
lam.latitudes,
lam.longitudes,
self.globe.latitudes,
self.globe.longitudes,
plot=False,
min_distance_km=self.min_distance_km,
cropping_distance=self.cropping_distance,
neighbours=self.neighbours,
)

# Ensure the mask dimensions match the global grid points
if global_overlap_mask.shape[0] != self.globe.shape[-1]:
raise ValueError("Global mask dimension does not match global dataset grid " "points.")
self.global_mask[~global_overlap_mask] = False

# Create a mask for the LAM datasets hierarchically, excluding
# points from previous LAMs
lam_current_mask = np.ones(lam.shape[-1], dtype=bool)
if i > 0:
for j in range(i):
prev_lam = self.lams[j]
prev_lam_lats = prev_lam.latitudes
prev_lam_lons = prev_lam.longitudes
# Check for overlap by computing distances
if self.has_overlap(prev_lam_lats, prev_lam_lons, lam_lats, lam_lons):
lam_overlap_mask = cutout_mask(
prev_lam_lats,
prev_lam_lons,
lam_lats,
lam_lons,
plot=False,
min_distance_km=self.min_distance_km,
cropping_distance=self.cropping_distance,
neighbours=self.neighbours,
)
lam_current_mask[~lam_overlap_mask] = False
self.masks.append(lam_current_mask)

def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0):
"""Checks for overlapping points between two sets of latitudes and
longitudes within a specified distance threshold.
Args:
lats1, lons1 (np.ndarray): Latitude and longitude arrays for the
first dataset.
lats2, lons2 (np.ndarray): Latitude and longitude arrays for the
second dataset.
distance_threshold (float): Distance in degrees to consider as
overlapping.
Returns:
bool: True if any points overlap within the distance threshold,
otherwise False.
"""
# Create KDTree for the first set of points
tree = cKDTree(np.vstack((lats1, lons1)).T)

# Query the second set of points against the first tree
distances, _ = tree.query(np.vstack((lats2, lons2)).T, k=1)

# Check if any distance is less than the specified threshold
return np.any(distances < distance_threshold)

def __getitem__(self, index):
"""Retrieves data from the masked LAMs and global dataset based on the
given index.
Args:
index (int or slice or tuple): Index specifying the data to
retrieve.
Returns:
np.ndarray: Data array from the masked datasets based on the index.
"""
if isinstance(index, (int, slice)):
index = (index, slice(None), slice(None), slice(None))
return self._get_tuple(index)

def _get_tuple(self, index):
"""Helper method that applies masks and retrieves data from each dataset
according to the specified index.
Args:
index (tuple): Index specifying slices to retrieve data.
Returns:
np.ndarray: Concatenated data array from all datasets based on the
index.
"""
index, changes = index_to_slices(index, self.shape)
# Select data from each LAM
lam_data = [lam[index] for lam in self.lams]

# First apply spatial indexing on `self.globe` and then apply the mask
globe_data_sliced = self.globe[index[:3]]
globe_data = globe_data_sliced[..., self.global_mask]

# Concatenate LAM data with global data
result = np.concatenate(lam_data + [globe_data], axis=self.axis)
return apply_index_to_slices_changes(result, changes)

def collect_supporting_arrays(self, collected, *path):
collected.append((path, "cutout_mask", self.mask))
"""Collects supporting arrays, including masks for each LAM and the global
dataset.
Args:
collected (list): List to which the supporting arrays are appended.
*path: Variable length argument list specifying the paths for the masks.
"""
# Append masks for each LAM
for i, (lam, mask) in enumerate(zip(self.lams, self.masks)):
collected.append((path + (f"lam_{i}",), "cutout_mask", mask))

# Append the global mask
collected.append((path + ("global",), "cutout_mask", self.global_mask))

@cached_property
def shape(self):
shape = self.lam.shape
# Number of non-zero masked values in the globe dataset
nb_globe = np.count_nonzero(self.mask)
return shape[:-1] + (shape[-1] + nb_globe,)
"""Returns the shape of the Cutout, accounting for retained grid points
across all LAMs and the global dataset.
Returns:
tuple: Shape of the concatenated masked datasets.
"""
shapes = [np.sum(mask) for mask in self.masks]
global_shape = np.sum(self.global_mask)
return tuple(self.lams[0].shape[:-1] + (sum(shapes) + global_shape,))

    def check_same_resolution(self, d1, d2):
        # Intentionally a no-op: Cutout combines datasets of different
        # resolutions (LAMs vs. the global grid), so the base-class
        # same-resolution check is disabled.
        pass

@property
def latitudes(self):
return np.concatenate([self.lam.latitudes, self.globe.latitudes[self.mask]])
def grids(self):
"""Returns the number of grid points for each LAM and the global dataset
after applying masks.
@property
def longitudes(self):
return np.concatenate([self.lam.longitudes, self.globe.longitudes[self.mask]])
Returns:
tuple: Count of retained grid points for each dataset.
"""
grids = [np.sum(mask) for mask in self.masks]
grids.append(np.sum(self.global_mask))
return tuple(grids)

def __getitem__(self, index):
if isinstance(index, (int, slice)):
index = (index, slice(None), slice(None), slice(None))
return self._get_tuple(index)
@property
def latitudes(self):
"""Returns the concatenated latitudes of each LAM and the global dataset
after applying masks.
@debug_indexing
@expand_list_indexing
def _get_tuple(self, index):
assert self.axis >= len(index) or index[self.axis] == slice(
None
), f"No support for selecting a subset of the 1D values {index} ({self.tree()})"
index, changes = index_to_slices(index, self.shape)
Returns:
np.ndarray: Concatenated latitude array for the masked datasets.
"""
lam_latitudes = np.concatenate([lam.latitudes[mask] for lam, mask in zip(self.lams, self.masks)])

# In case index_to_slices has changed the last slice
index, _ = update_tuple(index, self.axis, slice(None))
assert (
len(lam_latitudes) + len(self.globe.latitudes[self.global_mask]) == self.shape[-1]
), "Mismatch in number of latitudes"

lam_data = self.lam[index]
globe_data = self.globe[index]
latitudes = np.concatenate([lam_latitudes, self.globe.latitudes[self.global_mask]])
return latitudes

globe_data = globe_data[:, :, :, self.mask]
@property
def longitudes(self):
"""Returns the concatenated longitudes of each LAM and the global dataset
after applying masks.
result = np.concatenate([lam_data, globe_data], axis=self.axis)
Returns:
np.ndarray: Concatenated longitude array for the masked datasets.
"""
lam_longitudes = np.concatenate([lam.longitudes[mask] for lam, mask in zip(self.lams, self.masks)])

return apply_index_to_slices_changes(result, changes)
assert (
len(lam_longitudes) + len(self.globe.longitudes[self.global_mask]) == self.shape[-1]
), "Mismatch in number of longitudes"

@property
def grids(self):
for d in self.datasets:
if len(d.grids) > 1:
raise NotImplementedError("CutoutGrids does not support multi-grids datasets as inputs")
shape = self.lam.shape
return (shape[-1], self.shape[-1] - shape[-1])
longitudes = np.concatenate([lam_longitudes, self.globe.longitudes[self.global_mask]])
return longitudes

def tree(self):
"""Generates a hierarchical tree structure for the `Cutout` instance and
its associated datasets.
Returns:
Node: A `Node` object representing the `Cutout` instance as the root
node, with each dataset in `self.datasets` represented as a child
node.
"""
return Node(self, [d.tree() for d in self.datasets])

# def metadata_specific(self):
# return super().metadata_specific(
# mask=serialise_mask(self.mask),
# )


def grids_factory(args, kwargs):
if "ensemble" in kwargs:
Expand Down
42 changes: 42 additions & 0 deletions tools/grids/grids3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Test configuration for building a small rotated limited-area dataset
# from MARS (surface + pressure-level fields, accumulations and forcings).
common:
  mars_request: &mars_request
    expver: "0001"
    grid: 0.25/0.25
    area: [40, 25, 20, 60]
    rotation: [-20, -40]

dates:
  start: 2024-01-01 00:00:00
  end: 2024-01-01 18:00:00
  frequency: 6h

input:
  join:
    - mars:
        <<: *mars_request
        param: [2t, 10u, 10v, lsm]
        levtype: sfc
        stream: oper
        type: an
    - mars:
        <<: *mars_request
        param: [q, t, z]
        levtype: pl
        level: [50, 100]
        stream: oper
        type: an
    - accumulations:
        <<: *mars_request
        levtype: sfc
        param: [cp, tp]
    - forcings:
        template: ${input.join.0.mars}
        param:
          - cos_latitude
          - sin_latitude

output:
  order_by: [valid_datetime, param_level, number]
  remapping:
    param_level: "{param}_{levelist}"
  statistics: param_level
Loading

0 comments on commit e1ab0b8

Please sign in to comment.